Module mimir.attacks.attack_utils
Utility functions for attacks
Functions
def apply_extracted_fills(masked_texts: List[str], extracted_fills)
def apply_extracted_fills(masked_texts: List[str], extracted_fills):
    # split masked text into tokens, only splitting on spaces (not newlines)
    tokens = [x.split(" ") for x in masked_texts]

    n_expected = count_masks(masked_texts)

    # replace each mask token with the corresponding fill
    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
        if len(fills) < n:
            tokens[idx] = []
        else:
            for fill_idx in range(n):
                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]

    # join tokens back into text
    texts = [" ".join(x) for x in tokens]
    return texts
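A minimal usage sketch; the texts and fills below are illustrative, and in practice extracted_fills would typically come from a mask-filling model such as T5:

    from mimir.attacks.attack_utils import apply_extracted_fills

    # one list of fills per masked text, in mask order
    masked_texts = ["the cat <extra_id_0> on the <extra_id_1>"]
    extracted_fills = [["sat", "mat"]]

    print(apply_extracted_fills(masked_texts, extracted_fills))
    # ['the cat sat on the mat']
    # A text that receives fewer fills than it has masks becomes the empty string.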
def count_masks(texts)
def count_masks(texts):
    return [
        len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts
    ]
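For example (illustrative input):

    from mimir.attacks.attack_utils import count_masks

    print(count_masks(["the cat <extra_id_0> on the <extra_id_1>", "no masks here"]))
    # [2, 0]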
def f1_score(prediction, ground_truth)
def f1_score(prediction, ground_truth):
    """
    Compute F1 score for given prediction and ground truth.
    """
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    print(num_same, f1, precision, recall)
    return f1, precision, recall
Compute F1 score for given prediction and ground truth.
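The arguments are treated as token multisets, so sequences of tokens are expected. A small hand-checked sketch with illustrative tokens:

    from mimir.attacks.attack_utils import f1_score

    prediction = ["the", "cat", "sat"]
    ground_truth = ["the", "cat", "slept"]

    f1, precision, recall = f1_score(prediction, ground_truth)
    # two overlapping tokens out of three on each side: precision = recall = f1 = 2/3
    # (the function also prints the intermediate values)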
def filter_out_nan(x)
def filter_out_nan(x):
    return [element for element in x if not math.isnan(element)]
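For example:

    import math
    from mimir.attacks.attack_utils import filter_out_nan

    print(filter_out_nan([0.5, math.nan, 1.2]))  # [0.5, 1.2]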
def get_auc_from_thresholds(preds_member, preds_nonmember, thresholds)
def get_auc_from_thresholds(preds_member, preds_nonmember, thresholds):
    """
    Compute FPRs and TPRs corresponding to given thresholds
    """
    tpr, fpr = [], []
    for threshold in thresholds:
        tp = np.sum(preds_nonmember >= threshold)
        fn = np.sum(preds_nonmember < threshold)
        fp = np.sum(preds_member >= threshold)
        tn = np.sum(preds_member < threshold)
        tpr.append(tp / (tp + fn))
        fpr.append(fp / (fp + tn))
    tpr = np.array(tpr)
    fpr = np.array(fpr)
    roc_auc = auc(fpr, tpr)
    return roc_auc
Compute FPRs and TPRs corresponding to the given thresholds, and return the resulting ROC AUC.
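A minimal sketch, assuming the scores are passed as NumPy arrays so the elementwise comparisons work; the values below are illustrative. Note that non-member scores above a threshold are counted as true positives, i.e. larger raw scores are treated as indicating non-members:

    import numpy as np
    from mimir.attacks.attack_utils import get_auc_from_thresholds

    preds_member = np.array([0.1, 0.2, 0.3, 0.4])      # e.g. lower loss for members
    preds_nonmember = np.array([0.5, 0.6, 0.7, 0.8])
    thresholds = np.linspace(0.0, 1.0, 101)

    roc_auc = get_auc_from_thresholds(preds_member, preds_nonmember, thresholds)
    # cleanly separated scores like these yield an AUC of 1.0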
def get_precision_recall_metrics(preds_member, preds_nonmember)
def get_precision_recall_metrics(preds_member, preds_nonmember):
    preds_member_ = filter_out_nan(preds_member)
    preds_nonmember_ = filter_out_nan(preds_nonmember)
    total_preds = preds_member_ + preds_nonmember_
    total_labels = [0] * len(preds_member_) + [1] * len(preds_nonmember_)

    precision, recall, _ = precision_recall_curve(total_labels, total_preds)
    pr_auc = auc(recall, precision)

    return precision.tolist(), recall.tolist(), float(pr_auc)
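A minimal sketch with illustrative scores. Non-members are labelled as the positive class (label 1) and the raw scores are used directly, which matches attacks where members tend to receive lower scores:

    from mimir.attacks.attack_utils import get_precision_recall_metrics

    preds_member = [0.1, 0.2, 0.3]       # e.g. lower loss for members
    preds_nonmember = [0.7, 0.8, 0.9]

    precision, recall, pr_auc = get_precision_recall_metrics(preds_member, preds_nonmember)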
def get_roc_metrics(preds_member, preds_nonmember, perform_bootstrap: bool = False, return_thresholds: bool = False)
def get_roc_metrics(
    preds_member,
    preds_nonmember,
    perform_bootstrap: bool = False,
    return_thresholds: bool = False,
):  # fpr_list,
    preds_member_ = filter_out_nan(preds_member)
    preds_nonmember_ = filter_out_nan(preds_nonmember)
    total_preds = preds_member_ + preds_nonmember_
    # While roc_auc is unaffected by which class we consider
    # positive/negative, the TPR@lowFPR calculation is.
    # Make sure the members are positive class (larger values, so negate the raw MIA scores)
    total_preds = np.array(total_preds) * -1
    # Assign label '0' to members for computation, since sklearn
    # expects label '0' data to have lower values to get assigned that label
    # which is true for our attacks (lower loss for members, e.g.)
    total_labels = [1] * len(preds_member_) + [0] * len(preds_nonmember_)

    fpr, tpr, thresholds = roc_curve(total_labels, total_preds)
    roc_auc = auc(fpr, tpr)
    # tpr_at_low_fpr = {upper_bound: tpr[np.where(np.array(fpr) < upper_bound)[0][-1]] for upper_bound in fpr_list}

    if perform_bootstrap:

        def roc_auc_statistic(preds, labels):
            in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
            out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
            _, _, roc_auc = get_roc_metrics(in_preds, out_preds)
            return roc_auc

        auc_roc_res = bootstrap(
            (total_preds, total_labels),
            roc_auc_statistic,
            n_resamples=1000,
            paired=True,
        )

        # tpr_at_low_fpr_res = {}
        # for ub in fpr_list:
        #     def tpr_at_fpr_statistic(preds, labels):
        #         in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
        #         out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
        #         _, _, _, tpr_at_low_fpr = get_roc_metrics(in_preds, out_preds, [ub])
        #         return tpr_at_low_fpr[ub]
        #     tpr_at_low_fpr_res[ub] = bootstrap((total_preds, total_labels), tpr_at_fpr_statistic, n_resamples=1000, paired=True)

        if return_thresholds:
            return (
                fpr.tolist(),
                tpr.tolist(),
                float(roc_auc),
                auc_roc_res,
                thresholds.tolist(),
            )
        return (
            fpr.tolist(),
            tpr.tolist(),
            float(roc_auc),
            auc_roc_res,
        )  # tpr_at_low_fpr, tpr_at_low_fpr_res

    if return_thresholds:
        return fpr.tolist(), tpr.tolist(), float(roc_auc), thresholds.tolist()
    return fpr.tolist(), tpr.tolist(), float(roc_auc)  # , tpr_at_low_fpr
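A minimal usage sketch with illustrative scores. The inputs are raw MIA scores where lower values indicate membership (e.g. per-sample loss); the function negates them internally so that members become the positive class:

    from mimir.attacks.attack_utils import get_roc_metrics

    preds_member = [0.1, 0.2, 0.3]       # raw scores for known training members
    preds_nonmember = [0.7, 0.8, 0.9]    # raw scores for non-members

    fpr, tpr, roc_auc = get_roc_metrics(preds_member, preds_nonmember)

    # With perform_bootstrap=True, a scipy.stats bootstrap result for the AUC is
    # returned as an additional value; return_thresholds=True also appends the
    # thresholds from sklearn's roc_curve.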