# Source code for OpenAttack.victim.classifiers.base

from typing import Callable, Dict, List, Tuple
import numpy as np
from ..base import Victim
from .methods import *
from ...tags import Tag, TAG_Classification

# Registry of optional victim hooks, keyed by the attribute name a Classifier
# subclass defines. Classifier.__init_subclass__ iterates this mapping: for each
# name the subclass has, it wraps the method with the entry's
# ``method_decorator`` and records a ``Tag(name, "victim")``.
CLASSIFIER_METHODS : Dict[str, VictimMethod] = dict(
    get_pred=GetPredict(),
    get_prob=GetProbability(),
    get_grad=GetGradient(),
    get_embedding=GetEmbedding(),
)

class Classifier(Victim):
    """
    Classifier is the base class of all classifiers.

    A subclass declares its capabilities by defining any of the methods named
    in ``CLASSIFIER_METHODS`` (``get_pred``, ``get_prob``, ``get_grad``,
    ``get_embedding``). When the subclass is created, ``__init_subclass__``:

    * wraps each such method with the registered ``VictimMethod``'s
      ``method_decorator``;
    * adds ``TAG_Classification`` plus one ``Tag(<name>, "victim")`` per
      available method;
    * forwards the ``(name, VictimMethod)`` pairs and the tags to
      ``Victim.__init_subclass__``.
    """

    # Expected signatures of the optional hooks. These are bare annotations,
    # so they create no class attributes and ``hasattr(cls, name)`` stays
    # False until a subclass actually provides the method.
    get_pred : Callable[[List[str]], np.ndarray]
    get_prob : Callable[[List[str]], np.ndarray]
    get_grad : Callable[[List[str]], Tuple[np.ndarray, np.ndarray]]
    # ``get_embedding`` is also recognised through CLASSIFIER_METHODS; its
    # signature is not annotated here.

    def __init_subclass__(cls):
        """Decorate the new subclass's victim methods and register its tags.

        A method counts as available when ``hasattr(cls, name)`` is true, so it
        may be defined on ``cls`` itself or inherited. Every available method
        contributes an ``invoke_funcs`` entry and a ``victim`` tag.

        A method is wrapped with ``method_decorator`` unless the class that
        defines it (first owner in ``cls.__mro__``) is an ancestor Classifier
        subclass. Such an owner already had the method wrapped when it was
        created, so wrapping again would stack one extra decorator per
        inheritance level.
        """
        invoke_funcs = []
        tags = [TAG_Classification]

        for func_name, method in CLASSIFIER_METHODS.items():
            if not hasattr(cls, func_name):
                continue

            invoke_funcs.append((func_name, method))
            tags.append(Tag(func_name, "victim"))

            # Find the class whose own namespace provides the attribute.
            # ``None`` when it is supplied some other way (e.g. by a metaclass);
            # in that case it is decorated as usual.
            owner = next(
                (klass for klass in cls.__mro__ if func_name in vars(klass)),
                None,
            )
            already_wrapped = (
                owner is not None
                and owner is not cls
                and owner is not Classifier  # Classifier itself never ran this hook
                and issubclass(owner, Classifier)
            )
            # NOTE(review): assumes the method was not replaced on ``owner``
            # after class creation (e.g. by monkeypatching); such a replacement
            # would not have been wrapped.
            if not already_wrapped:
                setattr(
                    cls,
                    func_name,
                    method.method_decorator(getattr(cls, func_name)),
                )

        # Inherited methods are still listed in invoke_funcs, so Victim receives
        # the full set of available methods.
        # NOTE(review): if Victim.__init_subclass__ wraps every entry again,
        # inherited methods may be re-wrapped there — confirm in Victim.
        super().__init_subclass__(invoke_funcs, tags)