Source code for OpenAttack.attack_assist.substitute.word.embed_based
from typing import Dict, Optional
from .base import WordSubstitute
from ....exceptions import WordNotInDictionaryException
import torch
from ....tags import *
# Default options for embedding-based substitution (presumably mirrors the
# ``cosine`` argument of EmbedBasedSubstitute; not referenced in the code shown here).
DEFAULT_CONFIG = dict(cosine=False)
class EmbedBasedSubstitute(WordSubstitute):
    """Propose substitutes for a word from its nearest neighbours in an embedding matrix."""

    def __init__(self, word2id: Dict[str, int], embedding: torch.Tensor, cosine: bool = False,
                 k: Optional[int] = 50, threshold: Optional[float] = 0.5, device=None):
        """
        Embedding based word substitute.

        Args:
            word2id: A `dict` mapping words to row indexes of ``embedding``.
            embedding: A word embedding matrix; row ``word2id[w]`` is the vector of word ``w``.
            cosine: If ``True`` the cosine distance is used, otherwise the Euclidean
                distance is used. Default: ``False``
            k: Top-k results to return. If k is ``None``, all results will be returned.
                Default: 50
            threshold: Distance threshold; only candidates whose distance is strictly
                below it are returned. If ``None``, no threshold filtering is applied.
                Default: 0.5
            device: A PyTorch device for computing distances. Default: "cpu"
        """
        if device is None:
            device = "cpu"
        self.word2id = word2id
        self.embedding = embedding
        self.cosine = cosine
        self.k = k
        self.threshold = threshold
        # Reverse lookup used to turn row indexes back into words.
        self.id2word = {idx: word for word, idx in self.word2id.items()}
        if cosine:
            # Scale every row to unit L2 norm so the cosine distance reduces to
            # ``1 - dot product`` in ``substitute``.
            # NOTE(review): an all-zero row (e.g. a padding vector) divides by zero
            # here and becomes NaN; confirm the matrix has no such rows.
            self.embedding = self.embedding / self.embedding.norm(dim=1, keepdim=True)
        self.embedding = self.embedding.to(device)

    def __call__(self, word: str, pos: Optional[str] = None):
        """Return substitutes for ``word``; identical to ``substitute(word, pos)``."""
        return self.substitute(word, pos)

    def substitute(self, word: str, pos: Optional[str]):
        """
        Find the vocabulary words closest to ``word`` in embedding space.

        Args:
            word: The word to find substitutes for.
            pos: Accepted for interface compatibility; not used by this substitute.

        Returns:
            A list of ``(candidate_word, distance)`` tuples sorted by ascending
            distance, truncated to the ``k`` nearest (unless ``k`` is ``None``) and then
            to the leading candidates whose distance is strictly below ``threshold``
            (unless ``threshold`` is ``None``). The query word itself is among the
            candidates.

        Raises:
            WordNotInDictionaryException: If ``word`` is not a key of ``word2id``.
        """
        if word not in self.word2id:
            raise WordNotInDictionaryException()
        query = self.embedding[self.word2id[word], :]
        if self.cosine:
            # Rows were unit-normalised in __init__, so the dot product is the
            # cosine similarity.
            dis = 1 - (query * self.embedding).sum(dim=1)
        else:
            dis = (query - self.embedding).norm(dim=1)

        # ``sort`` yields the same index order as ``argsort`` plus the sorted values.
        dists, order = dis.sort()
        if self.k is not None:
            dists, order = dists[:self.k], order[:self.k]
        if self.threshold is not None:
            # Keep the leading run of candidates below the threshold (stopping at the
            # first failure, as a NaN or a too-large distance would). cumprod turns the
            # 0/1 mask into ones up to the first 0 and zeros afterwards, so its sum is
            # the run length -- one device sync instead of one per candidate.
            below = (dists < self.threshold).long()
            n_keep = int(below.cumprod(dim=0).sum())
            dists, order = dists[:n_keep], order[:n_keep]

        # One bulk transfer to Python values instead of an ``.item()`` call per result.
        return [
            (self.id2word[id_], dist)
            for id_, dist in zip(order.tolist(), dists.tolist())
        ]