# Source code for OpenAttack.attackers.classification
from typing import Any
from ..victim.classifiers.base import Classifier
from .base import Attacker
from ..attack_assist.goal import ClassifierGoal
from ..tags import *
class ClassificationAttacker(Attacker):
    """
    The base class of all classification attackers.
    """

    def __call__(self, victim: Classifier, input_: Any):
        """Validate ``victim``, run ``self.attack`` on ``input_`` and verify the result.

        Args:
            victim: The model under attack. It must be a ``Classifier`` whose
                ``TAGS`` contain both ``Tag("get_pred", "victim")`` and
                ``TAG_Classification``. It must also pass ``self._victim_check``.
            input_: A mapping-like record. The original input is read from
                ``input_["x"]``. If the record has a ``"target"`` key, the
                attack is targeted at that label. Otherwise it is untargeted,
                i.e. it must move the victim away from its current prediction
                on ``input_["x"]``.

        Returns:
            The adversarial sample returned by ``self.attack``, or ``None``
            when ``self.attack`` returns ``None``.

        Raises:
            TypeError: If ``victim`` is not a ``Classifier``.
            AttributeError: If ``victim`` lacks the ``get_pred`` tag or the
                classification tag.
            RuntimeError: If the sample returned by ``self.attack`` does not
                satisfy the goal when re-evaluated on ``victim``.
            Any exception raised by ``self._victim_check`` also propagates,
            and indexing ``input_["x"]`` fails if that key is absent.
        """
        if not isinstance(victim, Classifier):
            raise TypeError("`victim` is an instance of `%s`, but `%s` expected" % (victim.__class__.__name__, "Classifier"))
        if Tag("get_pred", "victim") not in victim.TAGS:
            raise AttributeError("`%s` needs victim to support `%s` method" % (self.__class__.__name__, "get_pred"))
        # Attacker-specific requirements on the victim (hook on the base class).
        self._victim_check(victim)
        if TAG_Classification not in victim.TAGS:
            raise AttributeError("Victim model `%s` must be a classifier" % victim.__class__.__name__)

        if "target" in input_:
            goal = ClassifierGoal(input_["target"], targeted=True)
        else:
            # Untargeted: the goal is to change the victim's current prediction.
            origin_y = victim.get_pred([input_["x"]])[0]
            goal = ClassifierGoal(origin_y, targeted=False)

        adversarial_sample = self.attack(victim, input_["x"], goal)

        if adversarial_sample is not None:
            # Independently re-query the victim so a faulty attacker cannot
            # report a sample that does not actually meet the goal.
            y_adv = victim.get_pred([adversarial_sample])[0]
            if not goal.check(adversarial_sample, y_adv):
                # %s (not %d) so that building this diagnostic can never raise a
                # TypeError for non-integer labels and mask the real failure.
                raise RuntimeError(
                    "Check attacker result failed: result ([%s] %s) expect (%s%s)"
                    % (y_adv, adversarial_sample, "" if goal.targeted else "not ", goal.target)
                )
        return adversarial_sample