Source code for OpenAttack.attack_eval.attack_eval

import collections
import logging
import multiprocessing as mp
import sys
from typing import Any, Dict, Generator, Iterable, List, Optional, Union

from tqdm import tqdm

from ..utils import visualizer, result_visualizer, get_language, language_by_name
from .utils import worker_process, worker_init, attack_process
from ..tags import *
from ..text_process.tokenizer import Tokenizer, get_default_tokenizer
from ..victim.base import Victim
from ..attackers.base import Attacker
from ..metric import AttackMetric, MetricSelector

logger = logging.getLogger(__name__)

class AttackEval:
    """Run an attacker against a victim model and collect per-sample and aggregate metrics."""

    def __init__(
        self,
        attacker: Attacker,
        victim: Victim,
        language: Optional[str] = None,
        tokenizer: Optional[Tokenizer] = None,
        invoke_limit: Optional[int] = None,
        metrics: Optional[List[Union[AttackMetric, MetricSelector]]] = None,
    ):
        """
        `AttackEval` is a class used to evaluate attack metrics in OpenAttack.

        Args:
            attacker: An attacker, must be an instance of :py:class:`.Attacker` .
            victim: A victim model, must be an instance of :py:class:`.Victim` .
            language: The language used for the evaluation. If it is `None`, the language is
                selected from the attacker, the tokenizer (if given), the victim (if it declares a
                supported language) and every :py:class:`.AttackMetric` in `metrics`.
            tokenizer: A tokenizer used for visualization. Defaults to the default tokenizer of
                the selected language.
            invoke_limit: Limit on the number of model invokes. A sample whose attack used more
                queries than this is reported with "Query Exceeded" set to True.
            metrics: A list of metrics. Each element must be an instance of
                :py:class:`.AttackMetric` or :py:class:`.MetricSelector` . Defaults to no metrics.

        Raises:
            ValueError: If `language` is given but is not a supported language name.
            RuntimeError: If a :py:class:`.MetricSelector` provides no metric for the language.
            TypeError: If an element of `metrics` is neither an `AttackMetric` nor a
                `MetricSelector`.
        """
        # Copy into a list: `metrics` is walked twice below (language detection and
        # metric resolution), and a shared mutable default must not be used.
        metrics = [] if metrics is None else list(metrics)

        if language is None:
            candidates = [attacker]
            if tokenizer is not None:
                candidates.append(tokenizer)
            if victim.supported_language is not None:
                candidates.append(victim)
            candidates.extend(it for it in metrics if isinstance(it, AttackMetric))
            lang_tag = get_language(candidates)
        else:
            lang_tag = language_by_name(language)
            if lang_tag is None:
                raise ValueError("Unsupported language `%s` in attack eval" % language)

        self._tags = {lang_tag}
        self.tokenizer = get_default_tokenizer(lang_tag) if tokenizer is None else tokenizer
        self.attacker = attacker
        self.victim = victim

        # Resolve selectors to concrete metrics for the chosen language.
        self.metrics = []
        for it in metrics:
            if isinstance(it, MetricSelector):
                selected = it.select(lang_tag)
                if selected is None:
                    raise RuntimeError(
                        "`%s` does not support language %s" % (it.__class__.__name__, lang_tag.name)
                    )
                self.metrics.append(selected)
            elif isinstance(it, AttackMetric):
                self.metrics.append(it)
            else:
                raise TypeError(
                    "`metrics` got %s, expect `MetricSelector` or `AttackMetric`" % it.__class__.__name__
                )
        self.invoke_limit = invoke_limit

    @property
    def TAGS(self):
        """The set of tags of this evaluator (currently only its language tag)."""
        return self._tags

    def __measure(self, data, adversarial_sample):
        """Run every metric's `after_attack` hook and collect the non-None values by metric name."""
        ret = {}
        for it in self.metrics:
            value = it.after_attack(data, adversarial_sample)
            if value is not None:
                ret[it.name] = value
        return ret

    def __preprocess(self, data):
        """Apply every metric's `before_attack` hook in order.

        A hook that returns None leaves the sample as it is; otherwise its return value
        replaces the sample for the following hooks and for the attacker.
        """
        for it in self.metrics:
            ret = it.before_attack(data)
            if ret is not None:
                data = ret
        return data

    def __iter_metrics(self, iterable_result):
        """Turn `(original_sample, (adversarial_sample, attack_time, invoke_times))` pairs into result dicts."""
        for data, result in iterable_result:
            adversarial_sample, attack_time, invoke_times = result
            yield {
                "data": data,
                "success": adversarial_sample is not None,
                "result": adversarial_sample,
                "metrics": {
                    "Running Time": attack_time,
                    "Query Exceeded": self.invoke_limit is not None and invoke_times > self.invoke_limit,
                    "Victim Model Queries": invoke_times,
                    **self.__measure(data, adversarial_sample),
                },
            }

    def ieval(
        self,
        dataset: Iterable[Dict[str, Any]],
        num_workers: int = 0,
        chunk_size: Optional[int] = None,
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Iterable evaluation function of `AttackEval`; lazily yields one result per input sample.

        `dataset` is iterated exactly once, so one-shot iterables such as generators work.

        Args:
            dataset: An iterable dataset. Metrics' `before_attack` hooks are applied to each
                sample before it is passed to the attacker.
            num_workers: The number of processes running the attack algorithm.
                Default: 0 (running on the main process).
            chunk_size: Chunk size passed to the process pool. Defaults to `num_workers`.

        Yields:
            A dict with keys "data" (the original sample), "success", "result" (the adversarial
            sample, or None on failure) and "metrics".
        """
        if num_workers > 0:
            ctx = mp.get_context("spawn")
            if chunk_size is None:
                chunk_size = num_workers

            # `Pool.imap` consumes `tasks()` from its own feeder thread. Each original sample
            # is queued *before* its preprocessed copy is handed to the pool, and imap returns
            # results in submission order, so the oldest queued original always belongs to
            # the next result. deque append/popleft are thread-safe.
            originals = collections.deque()

            def tasks():
                for data in dataset:
                    originals.append(data)
                    yield self.__preprocess(data)

            with ctx.Pool(num_workers, initializer=worker_init,
                          initargs=(self.attacker, self.victim, self.invoke_limit)) as pool:
                results = pool.imap(worker_process, tasks(), chunksize=chunk_size)
                yield from self.__iter_metrics((originals.popleft(), result) for result in results)
        else:
            def pairs():
                for data in dataset:
                    processed = self.__preprocess(data)
                    yield data, attack_process(self.attacker, self.victim, processed, self.invoke_limit)

            yield from self.__iter_metrics(pairs())

    def __predict(self, data, sentences):
        """Query the victim on `sentences` with `data` as context.

        Returns the victim's `get_prob` output when it supports it, otherwise a list of
        integer labels from `get_pred`.

        Raises:
            RuntimeError: If the victim supports neither `get_prob` nor `get_pred`.
        """
        if Tag("get_prob", "victim") in self.victim.TAGS:
            use_prob = True
        elif Tag("get_pred", "victim") in self.victim.TAGS:
            use_prob = False
        else:
            raise RuntimeError("Invalid victim model")

        self.victim.set_context(data, None)
        try:
            if use_prob:
                return self.victim.get_prob(sentences)
            return [int(pred) for pred in self.victim.get_pred(sentences)]
        finally:
            # Always reset the victim's context, even if the query fails.
            self.victim.clear_context()

    def __visualize(self, index, res, write):
        """Render one evaluation result (numbered `index`) through `visualizer` using `write`."""
        x_orig = res["data"]["x"]
        if res["success"]:
            x_adv = res["result"]
            outputs = self.__predict(res["data"], [x_orig, x_adv])
            y_orig, y_adv = outputs[0], outputs[1]
        else:
            x_adv, y_adv = None, None
            y_orig = self.__predict(res["data"], [x_orig])[0]

        # Copy so the display-only "Succeed" entry does not leak into the result's metrics.
        info = dict(res["metrics"])
        info["Succeed"] = res["success"]
        visualizer(index, x_orig, y_orig, x_adv, y_adv, info, write, self.tokenizer)

    def eval(
        self,
        dataset: Iterable[Dict[str, Any]],
        total_len: Optional[int] = None,
        visualize: bool = False,
        progress_bar: bool = False,
        num_workers: int = 0,
        chunk_size: Optional[int] = None,
    ):
        """
        Evaluation function of `AttackEval`.

        Args:
            dataset: An iterable dataset.
            total_len: Total length of the dataset for the progress bar (used only if the
                dataset doesn't have a `__len__` attribute).
            visualize: Display a pretty result for each sample (classification victims only)
                and the final summary.
            progress_bar: Display a progress bar if `True`.
            num_workers: The number of processes running the attack algorithm.
                Default: 0 (running on the main process).
            chunk_size: Chunk size passed to the process pool. Defaults to `num_workers`.

        Returns:
            A dict of attack evaluation summaries. "Query Exceeded" is reported as a total,
            every other numeric metric as an average over the samples that reported it. An
            empty dataset yields an "Attack Success Rate" of 0.0.
        """
        if hasattr(dataset, "__len__"):
            total_len = len(dataset)

        if progress_bar:
            result_iterator = tqdm(self.ieval(dataset, num_workers, chunk_size), total=total_len)

            def write(x):
                # Print above the progress bar instead of breaking it.
                return tqdm.write(x, end="")
        else:
            result_iterator = self.ieval(dataset, num_workers, chunk_size)
            write = sys.stdout.write

        show_samples = visualize and (TAG_Classification in self.victim.TAGS)
        total_result = {}
        total_result_cnt = {}
        total_inst = 0
        success_inst = 0

        for index, res in enumerate(result_iterator, start=1):
            total_inst += 1
            success_inst += int(res["success"])

            if show_samples:
                self.__visualize(index, res, write)

            for kw, val in res["metrics"].items():
                if val is None:
                    continue
                total_result[kw] = total_result.get(kw, 0) + float(val)
                total_result_cnt[kw] = total_result_cnt.get(kw, 0) + 1

        summary = {
            "Total Attacked Instances": total_inst,
            "Successful Instances": success_inst,
            # Guard against ZeroDivisionError when the dataset is empty.
            "Attack Success Rate": success_inst / total_inst if total_inst > 0 else 0.0,
        }
        for kw, cnt in total_result_cnt.items():
            if kw == "Succeed":
                continue
            if kw == "Query Exceeded":
                summary["Total " + kw] = total_result[kw]
            else:
                summary["Avg. " + kw] = total_result[kw] / cnt

        if visualize:
            result_visualizer(summary, sys.stdout.write)
        return summary
## TODO generate adversarial samples