Module mimir.attacks.recall

ReCaLL attack (Xie et al., 2024): scores membership by the ratio of a document's log-likelihood conditioned on a non-member prefix to its unconditional log-likelihood. Reference implementation: https://github.com/ruoyuxie/recall/
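
For intuition, the ReCaLL score can be sketched outside mimir with plain Hugging Face transformers. This is a minimal sketch, not the module's implementation: the model name, shots, and document are placeholders, and the context-window handling done by the real code below is omitted.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

@torch.no_grad()
def avg_ll(text: str, condition: str = "") -> float:
    # Average log-likelihood of `text`, optionally conditioned on `condition`
    text_ids = tok(text, return_tensors="pt").input_ids
    if condition:
        cond_ids = tok(condition, return_tensors="pt").input_ids
        input_ids = torch.cat((cond_ids, text_ids), dim=1)
        labels = input_ids.clone()
        labels[:, : cond_ids.size(1)] = -100           # prefix tokens ignored by the loss
    else:
        input_ids, labels = text_ids, text_ids.clone()
    return -model(input_ids, labels=labels).loss.item()

shots = "Non-member shot 1.\nNon-member shot 2.\n"     # placeholder non-member prefix
doc = "Candidate document whose membership is tested."  # placeholder
score = avg_ll(doc, condition=shots) / avg_ll(doc)     # the ReCaLL score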

Classes

class ReCaLLAttack (config: ExperimentConfig, target_model: Model)
import torch
import numpy as np

from mimir.attacks.all_attacks import Attack
from mimir.config import ExperimentConfig
from mimir.models import Model


class ReCaLLAttack(Attack):

    # Note: this is a suboptimal implementation of the ReCaLL attack, due to
    # changes needed to integrate it alongside the other attacks. For a
    # better-performing version, please refer to:
    # https://github.com/ruoyuxie/recall

    def __init__(self, config: ExperimentConfig, target_model: Model):
        super().__init__(config, target_model, ref_model=None)
        self.prefix = None
    @torch.no_grad()
    def _attack(self, document, probs, tokens=None, **kwargs):
        recall_dict: dict = kwargs.get("recall_dict", None)
        assert recall_dict is not None, "recall_dict must be passed via kwargs"

        nonmember_prefix = recall_dict.get("prefix")
        num_shots = recall_dict.get("num_shots")
        avg_length = recall_dict.get("avg_length")

        assert nonmember_prefix, "nonmember_prefix should not be None or empty"
        assert num_shots, "num_shots should not be None or empty"
        assert avg_length, "avg_length should not be None or empty"

        # ReCaLL score: log-likelihood of the document conditioned on a
        # non-member prefix, divided by its unconditional log-likelihood
        lls = self.target_model.get_ll(document, probs=probs, tokens=tokens)
        ll_nonmember = self.get_conditional_ll(
            nonmember_prefix=nonmember_prefix,
            text=document,
            num_shots=num_shots,
            avg_length=avg_length,
            tokens=tokens,
        )
        recall = ll_nonmember / lls

        assert not np.isnan(recall)
        return recall
    
    def process_prefix(self, prefix, avg_length, total_shots):
        model = self.target_model
        tokenizer = self.target_model.tokenizer

        if self.prefix is not None:
            # We only need to process the prefix once; afterwards just return it
            return self.prefix

        max_length = model.max_length
        token_counts = [len(tokenizer.encode(shot)) for shot in prefix]

        target_token_count = avg_length
        total_tokens = sum(token_counts) + target_token_count
        if total_tokens <= max_length:
            self.prefix = prefix
            return self.prefix

        # Determine the maximum number of shots that fit within max_length,
        # counting from the end since the most recent shots are the ones kept
        max_shots = 0
        cumulative_tokens = target_token_count
        for count in reversed(token_counts):
            if cumulative_tokens + count <= max_length:
                max_shots += 1
                cumulative_tokens += count
            else:
                break

        # Truncate the prefix to the maximum number of shots that fit
        truncated_prefix = prefix[-max_shots:]
        print(
            f"\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. "
            f"Maximum number of shots is {max_shots}. Defaulting to maximum number of shots."
        )
        self.prefix = truncated_prefix
        return self.prefix
    
    def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None):
        assert nonmember_prefix, "nonmember_prefix should not be None or empty"

        model = self.target_model
        tokenizer = self.target_model.tokenizer

        if tokens is None:
            target_encodings = tokenizer(text=text, return_tensors="pt")
        else:
            target_encodings = tokens

        processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots)
        input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt")

        prefix_ids = input_encodings.input_ids.to(model.device)
        text_ids = target_encodings.input_ids.to(model.device)

        max_length = model.max_length

        if prefix_ids.size(1) >= max_length:
            raise ValueError("Prefix length exceeds or equals the model's maximum context window.")

        labels = torch.cat((prefix_ids, text_ids), dim=1)
        total_length = labels.size(1)

        total_loss = 0
        total_tokens = 0
        with torch.no_grad():
            # Slide a max_length window over the concatenated prefix + document
            for i in range(0, total_length, max_length):
                begin_loc = i
                end_loc = min(i + max_length, total_length)

                input_ids = labels[:, begin_loc:end_loc].to(model.device)
                target_ids = input_ids.clone()

                # Mask prefix tokens so they do not contribute to the loss
                if begin_loc < prefix_ids.size(1):
                    prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length)
                    target_ids[:, :prefix_overlap] = -100

                # Keep document tokens in this window as prediction targets
                if end_loc > total_length - text_ids.size(1):
                    target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length)
                    target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:]

                if torch.all(target_ids == -100):
                    continue

                outputs = model.model(input_ids, labels=target_ids)
                loss = outputs.loss
                if torch.isnan(loss):
                    print(
                        f"NaN detected in loss at iteration {i}. "
                        f"Non-masked target_ids size is {(target_ids != -100).sum().item()}"
                    )
                    continue
                non_masked_tokens = (target_ids != -100).sum().item()
                total_loss += loss.item() * non_masked_tokens
                total_tokens += non_masked_tokens

        # Token-weighted average loss; negate to return a log-likelihood
        average_loss = total_loss / total_tokens if total_tokens > 0 else 0
        return -average_loss
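
A hypothetical invocation sketch: it assumes a config (ExperimentConfig) and target_model (Model) constructed elsewhere in mimir, nonmember_shots and document are placeholders, and the recall_dict keys follow _attack above.

attack = ReCaLLAttack(config, target_model)
recall_dict = {
    "prefix": nonmember_shots,         # list[str] of non-member shots (placeholder)
    "num_shots": len(nonmember_shots),
    "avg_length": 300,                 # expected document length in tokens (placeholder)
}
score = attack._attack(document, probs=None, recall_dict=recall_dict)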

Ancestors

mimir.attacks.all_attacks.Attack

Methods

def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None)
    Computes the average log-likelihood of text conditioned on the non-member prefix, sliding a max_length window over the concatenated sequence and masking prefix tokens out of the loss.
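
    The windowed masking reduces to the following toy sketch (placeholder token IDs):

    import torch

    prefix_ids = torch.tensor([[11, 12, 13]])   # 3 prefix tokens (placeholder IDs)
    text_ids = torch.tensor([[21, 22]])         # 2 document tokens (placeholder IDs)
    labels = torch.cat((prefix_ids, text_ids), dim=1)

    target_ids = labels.clone()
    target_ids[:, : prefix_ids.size(1)] = -100  # mask the prefix, keep the document
    print(target_ids)                           # tensor([[-100, -100, -100, 21, 22]])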
def process_prefix(self, prefix, avg_length, total_shots)
    Truncates the list of prefix shots so that the shots plus an average-length document fit within the model's context window; the result is cached on the first call.
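
    A standalone sketch of the shot-budgeting arithmetic, with made-up token counts:

    max_length = 1024                 # model context window (made-up)
    avg_length = 300                  # expected document length in tokens (made-up)
    token_counts = [400, 350, 500]    # tokens per shot, oldest first (made-up)

    max_shots, cumulative = 0, avg_length
    for count in reversed(token_counts):  # the last shots are the ones kept
        if cumulative + count <= max_length:
            max_shots += 1
            cumulative += count
        else:
            break
    print(max_shots)  # 1: only the most recent 500-token shot fits next to the document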

Inherited members