Source code for OpenAttack.attackers.bert_attack

import copy
from typing import List, Optional, Union
import numpy as np
from transformers import BertConfig, BertTokenizer, BertForMaskedLM
import torch


from ..classification import ClassificationAttacker, Classifier, ClassifierGoal
from ...tags import TAG_English, Tag
from ...exceptions import WordNotInDictionaryException
from ...attack_assist.substitute.word import get_default_substitute, WordSubstitute
from ...attack_assist.filter_words import get_default_filter_words


class Feature(object):
    def __init__(self, seq_a, label):
        self.label = label
        self.seq = seq_a
        self.final_adverse = seq_a
        self.query = 0
        self.change = 0
        self.success = 0
        self.sim = 0.0
        self.changes = []

[docs]class BERTAttacker(ClassificationAttacker): @property def TAGS(self): return { self.__lang_tag, Tag("get_pred", "victim"), Tag("get_prob", "victim") }
[docs] def __init__(self, mlm_path : str = 'bert-base-uncased', k : int = 36, use_bpe : bool = True, sim_mat : Union[None, bool, WordSubstitute] = None, threshold_pred_score : float = 0.3, max_length : int = 512, device : Optional[torch.device] = None, filter_words : List[str] = None ): """ BERT-ATTACK: Adversarial Attack Against BERT Using BERT, Linyang Li, Ruotian Ma, Qipeng Guo, Xiangyang Xue, Xipeng Qiu, EMNLP2020 `[pdf] <https://arxiv.org/abs/2004.09984>`__ `[code] <https://github.com/LinyangLee/BERT-Attack>`__ Args: mlm_path: The path to the masked language model. **Default:** 'bert-base-uncased' k: The k most important words / sub-words to substitute for. **Default:** 36 use_bpe: Whether use bpe. **Default:** `True` sim_mat: Whether use cosine_similarity to filter out atonyms. Keep `None` for not using a sim_mat. threshold_pred_score: Threshold used in substitute module. **Default:** 0.3 max_length: The maximum length of an input sentence for bert. **Default:** 512 device: A computing device for bert. filter_words: A list of words that will be preserved in the attack procesudre. :Classifier Capacity: * get_pred * get_prob """ self.tokenizer_mlm = BertTokenizer.from_pretrained(mlm_path, do_lower_case=True) if device is not None: self.device = device else: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") config_atk = BertConfig.from_pretrained(mlm_path) self.mlm_model = BertForMaskedLM.from_pretrained(mlm_path, config=config_atk).to(self.device) self.k = k self.use_bpe = use_bpe self.threshold_pred_score = threshold_pred_score self.max_length = max_length self.__lang_tag = TAG_English if filter_words is None: filter_words = get_default_filter_words(self.__lang_tag) self.filter_words = set(filter_words) if sim_mat is None or sim_mat is False: self.use_sim_mat = False else: self.use_sim_mat = True if sim_mat is True: self.substitute = get_default_substitute(self.__lang_tag) else: self.substitute = sim_mat
def attack(self, victim: Classifier, sentence, goal: ClassifierGoal): x_orig = sentence.lower() # return None tokenizer = self.tokenizer_mlm # MLM-process feature = Feature(x_orig, goal.target) words, sub_words, keys = self._tokenize(feature.seq, tokenizer) max_length = self.max_length # original label inputs = tokenizer.encode_plus(feature.seq, None, add_special_tokens=True, max_length=max_length, truncation=True) input_ids, _ = torch.tensor(inputs["input_ids"]), torch.tensor(inputs["token_type_ids"]) orig_probs = torch.Tensor(victim.get_prob([feature.seq])) orig_probs = orig_probs[0].squeeze() orig_probs = torch.softmax(orig_probs, -1) current_prob = orig_probs.max() sub_words = ['[CLS]'] + sub_words[:2] + sub_words[2:max_length - 2] + ['[SEP]'] input_ids_ = torch.tensor([tokenizer.convert_tokens_to_ids(sub_words)]) word_predictions = self.mlm_model(input_ids_.to(self.device))[0].squeeze() # seq-len(sub) vocab word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.k, -1) # seq-len k word_predictions = word_predictions[1:len(sub_words) + 1, :] word_pred_scores_all = word_pred_scores_all[1:len(sub_words) + 1, :] important_scores = self.get_important_scores(words, victim, current_prob, goal.target, orig_probs) feature.query += int(len(words)) list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True) final_words = copy.deepcopy(words) for top_index in list_of_index: if feature.change > int(0.2 * (len(words))): feature.success = 1 # exceed return None tgt_word = words[top_index[0]] if tgt_word in self.filter_words: continue if keys[top_index[0]][0] > max_length - 2: continue substitutes = word_predictions[keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k word_pred_scores = word_pred_scores_all[keys[top_index[0]][0]:keys[top_index[0]][1]] substitutes = self.get_substitues(substitutes, tokenizer, self.mlm_model, self.use_bpe, word_pred_scores, self.threshold_pred_score) if self.use_sim_mat: try: cfs_output = self.substitute(tgt_word) cos_sim_subtitutes = [elem[0] for elem in cfs_output] substitutes = list(set(substitutes) & set(cos_sim_subtitutes)) except WordNotInDictionaryException: pass # print("The target word is not representable by counter fitted vectors. Keeping the substitutes output by the MLM model.") most_gap = 0.0 candidate = None for substitute in substitutes: if substitute == tgt_word: continue # filter out original word if '##' in substitute: continue # filter out sub-word if substitute in self.filter_words: continue # if substitute in self.w2i and tgt_word in self.w2i: # if self.cos_mat[self.w2i[substitute]][self.w2i[tgt_word]] < 0.4: # continue temp_replace = final_words temp_replace[top_index[0]] = substitute temp_text = tokenizer.convert_tokens_to_string(temp_replace) inputs = tokenizer.encode_plus(temp_text, None, add_special_tokens=True, max_length=max_length, truncation=True) input_ids = torch.tensor(inputs["input_ids"]).unsqueeze(0).to(self.device) seq_len = input_ids.size(1) temp_prob = torch.Tensor(victim.get_prob([temp_text]))[0].squeeze() feature.query += 1 temp_prob = torch.softmax(temp_prob, -1) temp_label = torch.argmax(temp_prob) if goal.check(feature.final_adverse, temp_label): feature.change += 1 final_words[top_index[0]] = substitute feature.changes.append([keys[top_index[0]][0], substitute, tgt_word]) feature.final_adverse = temp_text feature.success = 4 return feature.final_adverse else: label_prob = temp_prob[goal.target] gap = current_prob - label_prob if gap > most_gap: most_gap = gap candidate = substitute if most_gap > 0: feature.change += 1 feature.changes.append([keys[top_index[0]][0], candidate, tgt_word]) current_prob = current_prob - most_gap final_words[top_index[0]] = candidate feature.final_adverse = (tokenizer.convert_tokens_to_string(final_words)) feature.success = 2 return None def _tokenize(self, seq, tokenizer): seq = seq.replace('\n', '').lower() words = seq.split(' ') sub_words = [] keys = [] index = 0 for word in words: sub = tokenizer.tokenize(word) sub_words += sub keys.append([index, index + len(sub)]) index += len(sub) return words, sub_words, keys def _get_masked(self, words): len_text = len(words) masked_words = [] for i in range(len_text - 1): masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:]) # list of words return masked_words def _get_masked_insert(self, words): len_text = len(words) masked_words = [] for i in range(len_text - 1): masked_words.append(words[0:i + 1] + ['[UNK]'] + words[i + 1:]) # list of words return masked_words def get_important_scores(self, words, tgt_model, orig_prob, orig_label, orig_probs): masked_words = self._get_masked(words) texts = [' '.join(words) for words in masked_words] # list of text of masked words leave_1_probs = torch.Tensor(tgt_model.get_prob(texts)) leave_1_probs = torch.softmax(leave_1_probs, -1) leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1) import_scores = (orig_prob - leave_1_probs[:, orig_label] + (leave_1_probs_argmax != orig_label).float() * (leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 0, leave_1_probs_argmax)) ).data.cpu().numpy() return import_scores def get_bpe_substitues(self, substitutes, tokenizer, mlm_model): # substitutes L, k substitutes = substitutes[0:12, 0:4] # maximum BPE candidates # find all possible candidates all_substitutes = [] for i in range(substitutes.size(0)): if len(all_substitutes) == 0: lev_i = substitutes[i] all_substitutes = [[int(c)] for c in lev_i] else: lev_i = [] for all_sub in all_substitutes: for j in substitutes[i]: lev_i.append(all_sub + [int(j)]) all_substitutes = lev_i # all substitutes list of list of token-id (all candidates) c_loss = torch.nn.CrossEntropyLoss(reduction='none') word_list = [] # all_substitutes = all_substitutes[:24] all_substitutes = torch.tensor(all_substitutes) # [ N, L ] all_substitutes = all_substitutes[:24].to(self.device) # print(substitutes.size(), all_substitutes.size()) N, L = all_substitutes.size() word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size ppl = c_loss(word_predictions.view(N*L, -1), all_substitutes.view(-1)) # [ N*L ] ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N _, word_list = torch.sort(ppl) word_list = [all_substitutes[i] for i in word_list] final_words = [] for word in word_list: tokens = [tokenizer._convert_id_to_token(int(i)) for i in word] text = tokenizer.convert_tokens_to_string(tokens) final_words.append(text) return final_words def get_substitues(self, substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0): # substitues L,k # from this matrix to recover a word words = [] sub_len, k = substitutes.size() # sub-len, k if sub_len == 0: return words elif sub_len == 1: for (i,j) in zip(substitutes[0], substitutes_score[0]): if threshold != 0 and j < threshold: break words.append(tokenizer._convert_id_to_token(int(i))) else: if use_bpe == 1: words = self.get_bpe_substitues(substitutes, tokenizer, mlm_model) else: return words return words def get_sim_embed(self, embed_path, sim_path): id2word = {} word2id = {} with open(embed_path, 'r', encoding='utf-8') as ifile: for line in ifile: word = line.split()[0] if word not in id2word: id2word[len(id2word)] = word word2id[word] = len(id2word) - 1 cos_sim = np.load(sim_path) return cos_sim, word2id, id2word