Expand source code
def get_attacker(attack: str):
    """Resolve an attack identifier to the class that implements it.

    The registry below pairs each ``AllAttacks`` member with its attack
    class. The class itself is returned, not an instance.

    Args:
        attack: Key to look up in the registry. NOTE(review): the annotation
            says ``str`` while the keys are ``AllAttacks`` members; plain
            strings only match if ``AllAttacks`` is a str-valued enum —
            confirm against its definition.

    Returns:
        The attack class registered for ``attack``.

    Raises:
        ValueError: If ``attack`` is not a registered key.
    """
    registry = {
        AllAttacks.LOSS: LOSSAttack,
        AllAttacks.REFERENCE_BASED: ReferenceAttack,
        AllAttacks.ZLIB: ZLIBAttack,
        AllAttacks.MIN_K: MinKProbAttack,
        AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
        AllAttacks.NEIGHBOR: NeighborhoodAttack,
        AllAttacks.GRADNORM: GradNormAttack,
        AllAttacks.RECALL: ReCaLLAttack,
        AllAttacks.DC_PDD: DC_PDDAttack,
    }
    # Guard clause: reject unknown keys before indexing.
    if attack not in registry:
        raise ValueError(f"Attack {attack} not found")
    return registry[attack]