Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
separate different types of word embeddings into different classes
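In outline, the change replaces the string-based `embedding_type`/`embedding_source` arguments with explicit embedding objects that share a common `WordEmbedding` interface. A rough before/after sketch based on the hunks below (illustrative only, not part of the diff):

    from textattack.transformations import WordSwapEmbedding

    # Before this commit: embeddings were selected by string flags.
    # transformation = WordSwapEmbedding(max_candidates=50, embedding_type="paragramcf")

    # After this commit: a WordEmbedding object is passed in explicitly
    # (or the counter-fitted GLOVE default is used).
    from textattack.shared import TextAttackWordEmbedding
    embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding()
    transformation = WordSwapEmbedding(max_candidates=50, embedding=embedding)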
@@ -2,7 +2,6 @@ bert-score>=0.3.5
 editdistance
 flair==0.6.1.post1
 filelock
-gensim==3.8.3
 language_tool_python
 lemminflect
 lru-dict

setup.py
@@ -31,6 +31,7 @@ extras["optional"] = [
     "stanza",
     "visdom",
     "wandb",
+    "gensim==3.8.3",
 ]

 # For developers, install development tools along with all optional dependencies.
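gensim thus moves from the hard requirements into the `optional` extra, so GensimWordEmbedding is only usable when gensim is installed. A hedged sketch of guarding for that (the install command `pip install textattack[optional]` is an assumption based on the extras key above):

    # Only construct a GensimWordEmbedding when the optional gensim dependency is present.
    try:
        import gensim  # noqa: F401
        from textattack.shared import GensimWordEmbedding
        embedding = GensimWordEmbedding("vectors.txt")  # hypothetical word2vec-format file
    except ImportError:
        embedding = None  # fall back to the built-in TextAttackWordEmbedding instead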
@@ -5,11 +5,11 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 50
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): WordEmbeddingDistance(
-        (embedding_type): paragramcf
+        (embedding): TextAttackWordEmbedding
        (min_cos_sim): 0.5
        (cased): False
        (include_unknown_words): True

@@ -3,7 +3,7 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 15
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): MaxWordsPerturbed(

@@ -11,7 +11,7 @@
        (compare_against_original): True
    )
    (1): ThoughtVector(
-        (embedding_type): paragramcf
+        (word_embedding): TextAttackWordEmbedding
        (metric): max_euclidean
        (threshold): -0.2
        (window_size): inf

@@ -10,7 +10,7 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 8
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): MaxWordsPerturbed(

@@ -18,7 +18,7 @@
        (compare_against_original): True
    )
    (1): WordEmbeddingDistance(
-        (embedding_type): paragramcf
+        (embedding): TextAttackWordEmbedding
        (max_mse_dist): 0.5
        (cased): False
        (include_unknown_words): True

@@ -5,7 +5,7 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 15
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): BERTScore(

@@ -5,7 +5,7 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 15
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): RepeatModification

@@ -12,7 +12,7 @@
        (compare_against_original): True
    )
    (1): WordEmbeddingDistance(
-        (embedding_type): paragramcf
+        (embedding): TextAttackWordEmbedding
        (min_cos_sim): 0.8
        (cased): False
        (include_unknown_words): True

@@ -3,7 +3,7 @@
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 15
-    (embedding_type): paragramcf
+    (embedding): TextAttackWordEmbedding
  )
  (constraints):
    (0): PartOfSpeech(
@@ -3,11 +3,11 @@ import os
 import numpy as np
 import pytest

-from textattack.shared import WordEmbedding
+from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding


 def test_embedding_paragramcf():
-    word_embedding = WordEmbedding()
+    word_embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding()
     assert pytest.approx(word_embedding[0][0]) == -0.022007
     assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
     assert word_embedding[10 ** 9] is None

@@ -28,7 +28,7 @@ bye-bye -1 1
     """
     )
     f.close()
-    word_embedding = WordEmbedding(embedding_type=path, embedding_source="gensim")
+    word_embedding = GensimWordEmbedding(path)
     assert pytest.approx(word_embedding[0][0]) == 1
     assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
     assert word_embedding[10 ** 9] is None

@@ -38,8 +38,8 @@ bye-bye -1 1
     # mse dist
     assert pytest.approx(word_embedding.get_mse_dist(0, 2)) == 4
     # nearest neighbour of hi is hello
-    assert word_embedding.nn(0, 1)[0] == 1
-    assert word_embedding.word2ind("bye") == 2
-    assert word_embedding.ind2word(3) == "bye-bye"
+    assert word_embedding.nearest_neighbours(0, 1)[0] == 1
+    assert word_embedding.word2index("bye") == 2
+    assert word_embedding.index2word(3) == "bye-bye"
     # remove test file
     os.remove(path)
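The gensim branch of this test builds a tiny word2vec-format file on disk; a standalone sketch of the same idea (the exact vectors here are made up, only the `bye-bye -1 1` row is taken from the hunk header above):

    import os, tempfile
    from textattack.shared import GensimWordEmbedding

    path = os.path.join(tempfile.mkdtemp(), "test_embedding.txt")
    with open(path, "w") as f:
        # word2vec text format: "<vocab size> <dim>" header, then one "<word> <values>" row per word
        f.write("4 2\nhi 1 0\nhello 0.9 0.1\nbye -1 0\nbye-bye -1 1\n")

    word_embedding = GensimWordEmbedding(path)
    print(word_embedding["bye-bye"])                # normalized 1-D vector
    print(word_embedding.nearest_neighbours(0, 1))  # list with the index of the closest word
    os.remove(path)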
@@ -49,11 +49,7 @@ class Kuleshov2017(AttackRecipe):
        #
        # Maximum thought vector Euclidean distance of λ_1 = 0.2. (eq. 4)
        #
-        constraints.append(
-            ThoughtVector(
-                embedding_type="paragramcf", threshold=0.2, metric="max_euclidean"
-            )
-        )
+        constraints.append(ThoughtVector(threshold=0.2, metric="max_euclidean"))
        #
        #
        # Maximum language model log-probability difference of λ_2 = 2. (eq. 5)
@@ -115,9 +115,7 @@ class EmbeddingAugmenter(Augmenter):
    def __init__(self, **kwargs):
        from textattack.transformations import WordSwapEmbedding

-        transformation = WordSwapEmbedding(
-            max_candidates=50, embedding_type="paragramcf"
-        )
+        transformation = WordSwapEmbedding(max_candidates=50)
        from textattack.constraints.semantics import WordEmbeddingDistance

        constraints = DEFAULT_CONSTRAINTS + [WordEmbeddingDistance(min_cos_sim=0.8)]
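For context, EmbeddingAugmenter now builds its WordSwapEmbedding with the default counter-fitted embedding instead of the old string flag; a short usage sketch (assuming the public augmentation API, output is illustrative):

    from textattack.augmentation import EmbeddingAugmenter

    augmenter = EmbeddingAugmenter()
    print(augmenter.augment("I enjoyed the movie a lot"))  # list of augmented strings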
@@ -7,7 +7,7 @@ import functools

 import torch

-from textattack.shared import WordEmbedding, utils
+from textattack.shared import TextAttackWordEmbedding, WordEmbedding, utils

 from .sentence_encoder import SentenceEncoder

@@ -16,14 +16,19 @@ class ThoughtVector(SentenceEncoder):
    """A constraint on the distance between two sentences' thought vectors.

    Args:
-        word_embedding (str): The word embedding to use
-        min_cos_sim: the minimum cosine similarity between thought vectors
-        max_mse_dist: the maximum euclidean distance between thought vectors
+        word_embedding (textattack.shared.WordEmbedding): The word embedding to use
    """

-    def __init__(self, embedding_type="paragramcf", embedding_source=None, **kwargs):
-        self.word_embedding = WordEmbedding(embedding_type, embedding_source)
-        self.embedding_type = embedding_type
+    def __init__(
+        self,
+        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
+        **kwargs
+    ):
+        if not isinstance(embedding, WordEmbedding):
+            raise ValueError(
+                "`embedding` object must be of type `textattack.shared.WordEmbedding`."
+            )
+        self.word_embedding = embedding
        super().__init__(**kwargs)

    def clear_cache(self):

@@ -46,4 +51,4 @@ class ThoughtVector(SentenceEncoder):

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys."""
-        return ["embedding_type"] + super().extra_repr_keys()
+        return ["word_embedding"] + super().extra_repr_keys()
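A sketch of using the reworked constraint (import path and keyword values assumed; `threshold` and `metric` are forwarded to SentenceEncoder through **kwargs, as in the Kuleshov2017 recipe above):

    from textattack.constraints.semantics.sentence_encoders import ThoughtVector
    from textattack.shared import TextAttackWordEmbedding

    # Default: the prebuilt counter-fitted GLOVE embedding.
    constraint = ThoughtVector(threshold=0.2, metric="max_euclidean")

    # Or pass any textattack.shared.WordEmbedding instance explicitly.
    constraint = ThoughtVector(
        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
        threshold=0.2,
        metric="max_euclidean",
    )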
@@ -1,10 +1,11 @@
 """
 Word Embedding Distance
--------------------------
+--------------------------
 """

 from textattack.constraints import Constraint
+from textattack.shared import TextAttackWordEmbedding, WordEmbedding
 from textattack.shared.validators import transformation_consists_of_word_swaps
-from textattack.shared.word_embedding import WordEmbedding


 class WordEmbeddingDistance(Constraint):

@@ -13,24 +14,17 @@ class WordEmbeddingDistance(Constraint):
    inserted.

    Args:
-        embedding_type (str): The word embedding to use.
-        embedding_source (str): The source of embeddings ("defaults" or "gensim")
-        include_unknown_words (bool): Whether or not the constraint is fulfilled
-            if the embedding of x or x_adv is unknown.
-        min_cos_sim: The minimum cosine similarity between word embeddings.
-        max_mse_dist: The maximum euclidean distance between word embeddings.
-        cased (bool): Whether embedding supports uppercase & lowercase
-            (defaults to False, or just lowercase).
-        compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
-            Otherwise, compare it against the previous `x_adv`.
+        embedding (obj): Wrapper for word embedding.
+        include_unknown_words (bool): Whether or not the constraint is fulfilled if the embedding of x or x_adv is unknown.
+        min_cos_sim (:obj:`float`, optional): The minimum cosine similarity between word embeddings.
+        max_mse_dist (:obj:`float`, optional): The maximum euclidean distance between word embeddings.
+        cased (bool): Whether embedding supports uppercase & lowercase (defaults to False, or just lowercase).
+        compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. Otherwise, compare it against the previous `x_adv`.
    """

-    PATH = "word_embeddings"

    def __init__(
        self,
-        embedding_type="paragramcf",
-        embedding_source=None,
+        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
        include_unknown_words=True,
        min_cos_sim=None,
        max_mse_dist=None,

@@ -40,14 +34,17 @@ class WordEmbeddingDistance(Constraint):
        super().__init__(compare_against_original)
        self.include_unknown_words = include_unknown_words
        self.cased = cased

        if bool(min_cos_sim) == bool(max_mse_dist):
            raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.")
        self.min_cos_sim = min_cos_sim
        self.max_mse_dist = max_mse_dist

-        self.embedding_type = embedding_type
-        self.embedding_source = embedding_source
-        self.embedding = WordEmbedding(
-            embedding_type=embedding_type, embedding_source=embedding_source
-        )
+        if not isinstance(embedding, WordEmbedding):
+            raise ValueError(
+                "`embedding` object must be of type `textattack.shared.WordEmbedding`."
+            )
+        self.embedding = embedding

    def get_cos_sim(self, a, b):
        """Returns the cosine similarity of words with IDs a and b."""

@@ -59,7 +56,7 @@ class WordEmbeddingDistance(Constraint):

    def _check_constraint(self, transformed_text, reference_text):
        """Returns true if (``transformed_text`` and ``reference_text``) are
-        closer than ``self.min_cos_sim`` and ``self.max_mse_dist``."""
+        closer than ``self.min_cos_sim`` or ``self.max_mse_dist``."""
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:

@@ -77,8 +74,8 @@ class WordEmbeddingDistance(Constraint):
            transformed_word = transformed_word.lower()

        try:
-            ref_id = self.embedding.word2ind(ref_word)
-            transformed_id = self.embedding.word2ind(transformed_word)
+            ref_id = self.embedding.word2index(ref_word)
+            transformed_id = self.embedding.word2index(transformed_word)
        except KeyError:
            # This error is thrown if x or x_adv has no corresponding ID.
            if self.include_unknown_words:

@@ -116,7 +113,7 @@ class WordEmbeddingDistance(Constraint):
        else:
            metric = "min_cos_sim"
        return [
-            "embedding_type",
+            "embedding",
            metric,
            "cased",
            "include_unknown_words",
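A usage sketch for the reworked WordEmbeddingDistance (values illustrative); note that exactly one of `min_cos_sim` or `max_mse_dist` must be supplied, otherwise the constructor raises the ValueError shown above:

    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.shared import TextAttackWordEmbedding

    # Uses the default counter-fitted GLOVE embedding.
    constraint = WordEmbeddingDistance(min_cos_sim=0.8)

    # Same constraint with the embedding passed explicitly and an MSE threshold instead.
    constraint = WordEmbeddingDistance(
        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
        max_mse_dist=0.5,
    )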
@@ -13,6 +13,6 @@ from .utils import logger
 from . import validators

 from .attacked_text import AttackedText
-from .word_embedding import WordEmbedding
+from .word_embedding import *
 from .attack import Attack
 from .checkpoint import Checkpoint
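With the wildcard import, the new classes are re-exported from textattack.shared, so downstream modules (such as the test and constraint files above) can import them directly:

    from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding, WordEmbedding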
@@ -49,15 +49,18 @@ def words_from_text(s, words_to_ignore=[]):


def default_class_repr(self):
-    extra_params = []
-    for key in self.extra_repr_keys():
-        extra_params.append(" (" + key + ")" + ": {" + key + "}")
-    if len(extra_params):
-        extra_str = "\n" + "\n".join(extra_params) + "\n"
-        extra_str = f"({extra_str})"
+    if hasattr(self, "extra_repr_keys"):
+        extra_params = []
+        for key in self.extra_repr_keys():
+            extra_params.append(" (" + key + ")" + ": {" + key + "}")
+        if len(extra_params):
+            extra_str = "\n" + "\n".join(extra_params) + "\n"
+            extra_str = f"({extra_str})"
+        else:
+            extra_str = ""
+        extra_str = extra_str.format(**self.__dict__)
    else:
        extra_str = ""
-    extra_str = extra_str.format(**self.__dict__)
    return f"{self.__class__.__name__}{extra_str}"
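The new hasattr guard lets objects that never define extra_repr_keys fall back to a bare class-name repr instead of raising AttributeError; a small illustration with hypothetical classes:

    from textattack.shared import utils

    class WithKeys:
        def __init__(self):
            self.max_candidates = 15
        def extra_repr_keys(self):
            return ["max_candidates"]
        __repr__ = utils.default_class_repr

    class WithoutKeys:
        __repr__ = utils.default_class_repr

    print(repr(WithKeys()))     # WithKeys( (max_candidates): 15 ), spread over three lines
    print(repr(WithoutKeys()))  # WithoutKeys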
@@ -3,6 +3,7 @@ Shared loads word embeddings and related distances
 =====================================================
 """

+from abc import ABC, abstractmethod
 from collections import defaultdict
 import os
 import pickle
@@ -10,77 +11,260 @@ import pickle

 import numpy as np
 import torch

-import textattack
 from textattack.shared import utils


-class WordEmbedding:
-    """An object that loads word embeddings and related distances.
+class WordEmbedding(ABC):
+    """Abstract class representing a word embedding used by TextAttack.
+
+    This class specifies all the methods that are required to be defined
+    so that it can be used for transformations and constraints. For a
+    custom word embedding not supported by TextAttack, please create a
+    class that inherits this object and implement the required methods.
+    However, please first check if you can use the TextAttackWordEmbedding
+    object, which has a lot of internal methods implemented.
+    """
+
+    @abstractmethod
+    def __getitem__(self, index):
+        """Gets the embedding vector for a word/id.
+        Args:
+            index (Union[str|int]): `index` can either be a word or an integer representing the id of the word.
+        Returns:
+            vector (ndarray): 1-D embedding vector. If no corresponding vector can be found for `index`, returns `None`.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_mse_dist(self, a, b):
+        """Return MSE (i.e. L2-norm) distance between the vector for word `a` and
+        the vector for word `b`.
+
+        Since this is a metric, `get_mse_dist(a, b)` and `get_mse_dist(b, a)` should return the same value.
+        Args:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
+        Returns:
+            distance (float): MSE (L2) distance
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_cos_sim(self, a, b):
+        """Return cosine similarity between the vector for word `a` and the vector for
+        word `b`.
+
+        Since this is a metric, `get_cos_sim(a, b)` and `get_cos_sim(b, a)` should return the same value.
+        Args:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
+        Returns:
+            distance (float): cosine similarity
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def word2index(self, word):
+        """Convert a word to its id (i.e. the index of the word in the embedding matrix).
+        Args:
+            word (str)
+        Returns:
+            index (int)
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def index2word(self, index):
+        """Convert an index to its corresponding word.
+        Args:
+            index (int)
+        Returns:
+            word (str)
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def nearest_neighbours(self, index, topn):
+        """Get the top-N nearest neighbours for a word.
+        Args:
+            index (int): ID of the word for which we're finding the nearest neighbours
+            topn (int): Used for specifying N nearest neighbours
+        Returns:
+            neighbours (list[int]): List of indices of the nearest neighbours
+        """
+        raise NotImplementedError()
+
+    __repr__ = __str__ = utils.default_class_repr
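As the docstring says, a custom embedding only has to implement these six methods; a minimal in-memory sketch (the two-word vocabulary and its vectors are invented purely for illustration):

    import numpy as np
    from textattack.shared import WordEmbedding

    class TinyDictEmbedding(WordEmbedding):
        """Toy embedding backed by a plain dict; illustration only."""

        def __init__(self):
            self._vectors = {"hi": np.array([1.0, 0.0]), "hello": np.array([0.9, 0.1])}
            self._words = list(self._vectors)

        def __getitem__(self, index):
            word = index if isinstance(index, str) else self._words[index]
            return self._vectors.get(word)

        def word2index(self, word):
            return self._words.index(word)

        def index2word(self, index):
            return self._words[index]

        def get_mse_dist(self, a, b):
            e1, e2 = self[a], self[b]
            return float(np.sum((e1 - e2) ** 2))

        def get_cos_sim(self, a, b):
            e1, e2 = self[a], self[b]
            return float(np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2)))

        def nearest_neighbours(self, index, topn):
            dists = [self.get_mse_dist(index, i) for i in range(len(self._words))]
            return [int(i) for i in np.argsort(dists) if i != index][:topn]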
class TextAttackWordEmbedding(WordEmbedding):
    """An object that loads word embeddings and related distances for
    TextAttack. This has a lot of internal components (e.g. get cosine
    similarity) implemented. Consider using this class if you can provide the
    appropriate input data to create the object.

    Args:
-        embedding_type (str): The type of the embedding to load automatically
-        embedding_source (str): Source of embeddings provided,
-            "defaults" corresponds to the textattack s3 bucket
-            "gensim" expects a word2vec model
+        embedding_matrix (ndarray): 2-D array of NxD where N represents the size of the vocab and D is the dimension of the embedding vectors.
+        word2index (Union[dict|object]): dictionary (or a similar object) that maps a word to its index within the embedding matrix.
+        index2word (Union[dict|object]): dictionary (or a similar object) that maps an index to its word.
+        nn_matrix (ndarray): Matrix of precomputed nearest neighbours. It should be a 2-D integer array of NxK
+            where N represents the size of the vocab and K is the number of top-K nearest neighbours. If this is not defined, we have to compute nearest neighbours
+            on the fly for the `nearest_neighbours` method, which is costly.
    """

    PATH = "word_embeddings"
-    EMBEDDINGS_AVAILABLE_IN_TEXTATTACK = {"paragramcf"}

-    def __init__(self, embedding_type=None, embedding_source=None):
+    def __init__(self, embedding_matrix, word2index, index2word, nn_matrix=None):
+        self.embedding_matrix = embedding_matrix
+        self._word2index = word2index
+        self._index2word = index2word
+        self.nn_matrix = nn_matrix

-        self.embedding_type = embedding_type or "paragramcf"
-        self.embedding_source = embedding_source or "defaults"
+        # Dictionary for caching results
+        self._mse_dist_mat = defaultdict(dict)
+        self._cos_sim_mat = defaultdict(dict)
+        self._nn_cache = {}

-        if self.embedding_source == "defaults":
-            if (
-                self.embedding_type
-                not in WordEmbedding.EMBEDDINGS_AVAILABLE_IN_TEXTATTACK
-            ):
-                raise ValueError(
-                    f"{self.embedding_type} is not available in TextAttack."
-                )
-        elif self.embedding_source not in ["gensim"]:
-            raise ValueError(f"{self.embedding_source} type is not supported.")
-
-        self._embeddings = None
-        self._word2index = None
-        self._index2word = None
-        self._cos_sim_mat = None
-        self._mse_dist_mat = None
-        self._nn = None
-
-        self._gensim_keyed_vectors = None
-
-        self._init_embeddings_from_type(self.embedding_source, self.embedding_type)

-    def _init_embeddings_from_defaults(self, embedding_type):
-        """
-        Init embeddings prepared in the textattack s3 bucket
+    def __getitem__(self, index):
+        """Gets the embedding vector for a word/id.
        Args:
-            embedding_type:
+            index (Union[str|int]): `index` can either be a word or an integer representing the id of the word.
        Returns:
+            vector (ndarray): 1-D embedding vector. If no corresponding vector can be found for `index`, returns `None`.
        """
+        if isinstance(index, str):
+            try:
+                index = self._word2index[index]
+            except KeyError:
+                return None
+        try:
+            return self.embedding_matrix[index]
+        except IndexError:
+            # word embedding ID out of bounds
+            return None
+
+    def word2index(self, word):
+        """Convert a word to its id (i.e. the index of the word in the embedding matrix).
+        Args:
+            word (str)
+        Returns:
+            index (int)
+        """
+        return self._word2index[word]
+
+    def index2word(self, index):
+        """Convert an index to its corresponding word.
+        Args:
+            index (int)
+        Returns:
+            word (str)
+        """
-        if embedding_type == "paragramcf":
-            word_embeddings_folder = "paragramcf"
-            word_embeddings_file = "paragram.npy"
-            word_list_file = "wordlist.pickle"
-            mse_dist_file = "mse_dist.p"
-            cos_sim_file = "cos_sim.p"
-            nn_matrix_file = "nn.npy"
+        return self._index2word[index]

+    def get_mse_dist(self, a, b):
+        """Return MSE (i.e. L2-norm) distance between the vector for word `a` and
+        the vector for word `b`.
+
+        Since this is a metric, `get_mse_dist(a, b)` and `get_mse_dist(b, a)` should return the same value.
+        Args:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
+        Returns:
+            distance (float): MSE (L2) distance
+        """
+        if isinstance(a, str):
+            a = self._word2index[a]
+        if isinstance(b, str):
+            b = self._word2index[b]
+        a, b = min(a, b), max(a, b)
+        try:
+            mse_dist = self._mse_dist_mat[a][b]
+        except KeyError:
+            e1 = self.embedding_matrix[a]
+            e2 = self.embedding_matrix[b]
+            e1 = torch.tensor(e1).to(utils.device)
+            e2 = torch.tensor(e2).to(utils.device)
+            mse_dist = torch.sum((e1 - e2) ** 2).item()
+            self._mse_dist_mat[a][b] = mse_dist
+
+        return mse_dist
+
+    def get_cos_sim(self, a, b):
+        """Return cosine similarity between the vector for word `a` and the vector for
+        word `b`.
+
+        Since this is a metric, `get_cos_sim(a, b)` and `get_cos_sim(b, a)` should return the same value.
+        Args:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
+        Returns:
+            distance (float): cosine similarity
+        """
+        if isinstance(a, str):
+            a = self._word2index[a]
+        if isinstance(b, str):
+            b = self._word2index[b]
+        a, b = min(a, b), max(a, b)
+        try:
+            cos_sim = self._cos_sim_mat[a][b]
+        except KeyError:
+            e1 = self.embedding_matrix[a]
+            e2 = self.embedding_matrix[b]
+            e1 = torch.tensor(e1).to(utils.device)
+            e2 = torch.tensor(e2).to(utils.device)
+            cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item()
+            self._cos_sim_mat[a][b] = cos_sim
+        return cos_sim
+
+    def nearest_neighbours(self, index, topn):
+        """Get the top-N nearest neighbours for a word.
+        Args:
+            index (int): ID of the word for which we're finding the nearest neighbours
+            topn (int): Used for specifying N nearest neighbours
+        Returns:
+            neighbours (list[int]): List of indices of the nearest neighbours
+        """
+        if isinstance(index, str):
+            index = self._word2index[index]
+        if self.nn_matrix is not None:
+            nn = self.nn_matrix[index][1 : (topn + 1)]
        else:
-            raise ValueError(f"Could not find word embedding {embedding_type}")
+            try:
+                nn = self._nn_cache[index]
+            except KeyError:
+                embedding = torch.tensor(self.embedding_matrix).to(utils.device)
+                vector = torch.tensor(self.embedding_matrix[index]).to(utils.device)
+                dist = torch.norm(embedding - vector, dim=1, p=None)
+                # Since the closest neighbour will be the same word, we consider N+1 nearest neighbours
+                nn = dist.topk(topn + 1, largest=False)[1:].tolist()
+                self._nn_cache[index] = nn
+
+        return nn
+
+    @staticmethod
+    def counterfitted_GLOVE_embedding():
+        """Returns a prebuilt counter-fitted GLOVE word embedding proposed by
+        "Counter-fitting Word Vectors to Linguistic Constraints" (Mrkšić et
+        al., 2016)."""
+        word_embeddings_folder = "paragramcf"
+        word_embeddings_file = "paragram.npy"
+        word_list_file = "wordlist.pickle"
+        mse_dist_file = "mse_dist.p"
+        cos_sim_file = "cos_sim.p"
+        nn_matrix_file = "nn.npy"

        # Download embeddings if they're not cached.
        word_embeddings_folder = os.path.join(
-            WordEmbedding.PATH, word_embeddings_folder
+            TextAttackWordEmbedding.PATH, word_embeddings_folder
        )
-        word_embeddings_folder = textattack.shared.utils.download_if_needed(
-            word_embeddings_folder
-        )
+        word_embeddings_folder = utils.download_if_needed(word_embeddings_folder)

        # Concatenate folder names to create full path to files.
        word_embeddings_file = os.path.join(
            word_embeddings_folder, word_embeddings_file
@@ -91,234 +275,152 @@ class WordEmbedding:
        nn_matrix_file = os.path.join(word_embeddings_folder, nn_matrix_file)

        # loading the files
-        self._embeddings = np.load(word_embeddings_file)
-        self._word2index = np.load(word_list_file, allow_pickle=True)
-        if os.path.exists(mse_dist_file):
-            with open(mse_dist_file, "rb") as f:
-                self._mse_dist_mat = pickle.load(f)
-        else:
-            self._mse_dist_mat = defaultdict(dict)
-        if os.path.exists(cos_sim_file):
-            with open(cos_sim_file, "rb") as f:
-                self._cos_sim_mat = pickle.load(f)
-        else:
-            self._cos_sim_mat = defaultdict(dict)
+        embedding_matrix = np.load(word_embeddings_file)
+        word2index = np.load(word_list_file, allow_pickle=True)
+        index2word = {}
+        for word, index in word2index.items():
+            index2word[index] = word
+        nn_matrix = np.load(nn_matrix_file)

-        self._nn = np.load(nn_matrix_file)
+        word_embedding = TextAttackWordEmbedding(
+            embedding_matrix, word2index, index2word, nn_matrix
+        )
+        with open(mse_dist_file, "rb") as f:
+            mse_dist_mat = pickle.load(f)
+        with open(cos_sim_file, "rb") as f:
+            cos_sim_mat = pickle.load(f)

-        # Build glove dict and index.
-        self._index2word = dict()
-        for word, index in self._word2index.items():
-            self._index2word[index] = word
+        word_embedding._mse_dist_mat = mse_dist_mat
+        word_embedding._cos_sim_mat = cos_sim_mat

-    def _init_embeddings_from_gensim(self, embedding_type):
-        """
-        Initialize word embedding from a gensim word2vec model
-        Args:
-            embedding_type:
-        Returns:
-        """
-        import gensim
+        return word_embedding


+class GensimWordEmbedding(WordEmbedding):
+    """Wraps Gensim's KeyedVectors
+    (https://radimrehurek.com/gensim/models/keyedvectors.html)"""

-        if embedding_type.endswith(".bin"):
-            self._gensim_keyed_vectors = (
-                gensim.models.KeyedVectors.load_word2vec_format(
-                    embedding_type, binary=True
-                )
-            )
+    def __init__(self, keyed_vectors_or_path):
+        gensim = utils.LazyLoader("gensim", globals(), "gensim")
+
+        if isinstance(keyed_vectors_or_path, str):
+            if keyed_vectors_or_path.endswith(".bin"):
+                self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(
+                    keyed_vectors_or_path, binary=True
+                )
+            else:
+                self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(
+                    keyed_vectors_or_path
+                )
+        elif isinstance(keyed_vectors_or_path, gensim.models.KeyedVectors):
+            self.keyed_vectors = keyed_vectors_or_path
        else:
-            self._gensim_keyed_vectors = (
-                gensim.models.KeyedVectors.load_word2vec_format(embedding_type)
-            )
+            raise ValueError(
+                "`keyed_vectors_or_path` argument must either be a `gensim.models.KeyedVectors` object "
+                "or a path pointing to the saved KeyedVectors object"
+            )
-        self._gensim_keyed_vectors.init_sims()

-    def _init_embeddings_from_type(self, embedding_source, embedding_type):
-        """Initializes embedding based on the source.
-
-        Downloads and loads embeddings into memory.
-        """
-        if embedding_source == "defaults":
-            self._init_embeddings_from_defaults(embedding_type)
-        elif embedding_source == "gensim":
-            self._init_embeddings_from_gensim(embedding_type)
-        else:
-            raise ValueError(f"Not supported word embedding source {embedding_source}")
+        self.keyed_vectors.init_sims()
+        self._mse_dist_mat = defaultdict(dict)
+        self._cos_sim_mat = defaultdict(dict)

    def __getitem__(self, index):
-        """Gets a word embedding by word or ID.
-
-        If word or ID not found, returns None.
-        """
-        if self.embedding_source == "defaults":
-            if isinstance(index, str):
-                try:
-                    index = self._word2index[index]
-                except KeyError:
-                    return None
-            try:
-                return self._embeddings[index]
-            except IndexError:
-                # word embedding ID out of bounds
-                return None
-        elif self.embedding_source == "gensim":
-            if isinstance(index, str):
-                try:
-                    index = self._gensim_keyed_vectors.vocab.get(index).index
-                except KeyError:
-                    return None
-            try:
-                return self._gensim_keyed_vectors.vectors_norm[index]
-            except IndexError:
-                # word embedding ID out of bounds
-                return None
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )
-
-    def get_cos_sim(self, a, b):
-        """
-        get cosine similarity of two words/IDs
+        """Gets the embedding vector for a word/id.
        Args:
-            a:
-            b:
+            index (Union[str|int]): `index` can either be a word or an integer representing the id of the word.
        Returns:
+            vector (ndarray): 1-D embedding vector. If no corresponding vector can be found for `index`, returns `None`.
        """
+        if isinstance(index, str):
+            try:
+                index = self.keyed_vectors.vocab.get(index).index
+            except KeyError:
+                return None
+        try:
+            return self.keyed_vectors.vectors_norm[index]
+        except IndexError:
+            # word embedding ID out of bounds
+            return None
+
+    def word2index(self, word):
+        """Convert a word to its id (i.e. the index of the word in the embedding matrix).
+        Args:
+            word (str)
+        Returns:
+            index (int)
+        """
+        vocab = self.keyed_vectors.vocab.get(word)
+        if vocab is None:
+            raise KeyError(word)
+        return vocab.index
+
+    def index2word(self, index):
+        """Convert an index to its corresponding word.
+        Args:
+            index (int)
+        Returns:
+            word (str)
+        """
-        if self.embedding_source == "defaults":
-            if isinstance(a, str):
-                a = self._word2index[a]
-            if isinstance(b, str):
-                b = self._word2index[b]
-            a, b = min(a, b), max(a, b)
-            try:
-                cos_sim = self._cos_sim_mat[a][b]
-            except KeyError:
-                e1 = self._embeddings[a]
-                e2 = self._embeddings[b]
-                e1 = torch.tensor(e1).to(utils.device)
-                e2 = torch.tensor(e2).to(utils.device)
-                cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item()
-                self._cos_sim_mat[a][b] = cos_sim
-            return cos_sim
-        elif self.embedding_source == "gensim":
-            if not isinstance(a, str):
-                a = self._gensim_keyed_vectors.index2word[a]
-            if not isinstance(b, str):
-                b = self._gensim_keyed_vectors.index2word[b]
-            cos_sim = self._gensim_keyed_vectors.similarity(a, b)
-            return cos_sim
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )
+        try:
+            # this is a list, so the error would be IndexError
+            return self.keyed_vectors.index2word[index]
+        except IndexError:
+            raise KeyError(index)

    def get_mse_dist(self, a, b):
-        """
-        get mse distance of two IDs
+        """Return MSE (i.e. L2-norm) distance between the vector for word `a` and
+        the vector for word `b`.
+
+        Since this is a metric, `get_mse_dist(a, b)` and `get_mse_dist(b, a)` should return the same value.
        Args:
-            a:
-            b:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
        Returns:
+            distance (float): MSE (L2) distance
        """
-        if self.embedding_source == "defaults":
-            a, b = min(a, b), max(a, b)
-            try:
-                mse_dist = self._mse_dist_mat[a][b]
-            except KeyError:
-                e1 = self._embeddings[a]
-                e2 = self._embeddings[b]
-                e1 = torch.tensor(e1).to(utils.device)
-                e2 = torch.tensor(e2).to(utils.device)
-                mse_dist = torch.sum((e1 - e2) ** 2).item()
-                self._mse_dist_mat[a][b] = mse_dist
-            return mse_dist
-        elif self.embedding_source == "gensim":
-            if self._mse_dist_mat is None:
-                self._mse_dist_mat = defaultdict(dict)
-            try:
-                mse_dist = self._mse_dist_mat[a][b]
-            except KeyError:
-                e1 = self._gensim_keyed_vectors.vectors_norm[a]
-                e2 = self._gensim_keyed_vectors.vectors_norm[b]
-                e1 = torch.tensor(e1).to(utils.device)
-                e2 = torch.tensor(e2).to(utils.device)
-                mse_dist = torch.sum((e1 - e2) ** 2).item()
-                self._mse_dist_mat[a][b] = mse_dist
-            return mse_dist
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )
+        try:
+            mse_dist = self._mse_dist_mat[a][b]
+        except KeyError:
+            e1 = self.keyed_vectors.vectors_norm[a]
+            e2 = self.keyed_vectors.vectors_norm[b]
+            e1 = torch.tensor(e1).to(utils.device)
+            e2 = torch.tensor(e2).to(utils.device)
+            mse_dist = torch.sum((e1 - e2) ** 2).item()
+            self._mse_dist_mat[a][b] = mse_dist
+        return mse_dist

-    def word2ind(self, word):
-        """
-        word to index
+    def get_cos_sim(self, a, b):
+        """Return cosine similarity between the vector for word `a` and the vector for
+        word `b`.
+
+        Since this is a metric, `get_cos_sim(a, b)` and `get_cos_sim(b, a)` should return the same value.
        Args:
-            word:
+            a (Union[str|int]): Either a word or an integer representing the id of the word
+            b (Union[str|int]): Either a word or an integer representing the id of the word
        Returns:
+            distance (float): cosine similarity
        """
-        if self.embedding_source == "defaults":
-            return self._word2index[word]
-        elif self.embedding_source == "gensim":
-            vocab = self._gensim_keyed_vectors.vocab.get(word)
-            if vocab is None:
-                raise KeyError(word)
-            return vocab.index
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )
+        if not isinstance(a, str):
+            a = self.keyed_vectors.index2word[a]
+        if not isinstance(b, str):
+            b = self.keyed_vectors.index2word[b]
+        cos_sim = self.keyed_vectors.similarity(a, b)
+        return cos_sim

-    def ind2word(self, index):
+    def nearest_neighbours(self, index, topn, return_words=True):
        """
-        index to word
+        Get the top-N nearest neighbours for a word.
        Args:
-            index:
+            index (int): ID of the word for which we're finding the nearest neighbours
+            topn (int): Used for specifying N nearest neighbours
        Returns:
+            neighbours (list[int]): List of indices of the nearest neighbours
        """
-        if self.embedding_source == "defaults":
-            return self._index2word[index]
-        elif self.embedding_source == "gensim":
-            try:
-                # this is a list, so the error would be IndexError
-                return self._gensim_keyed_vectors.index2word[index]
-            except IndexError:
-                raise KeyError(index)
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )

-    def nn(self, index, topn):
-        """
-        get top n nearest neighbours for a word
-        Args:
-            index:
-            topn:
-        Returns:
-        """
-        if self.embedding_source == "defaults":
-            return self._nn[index][1 : (topn + 1)]
-        elif self.embedding_source == "gensim":
-            word = self._gensim_keyed_vectors.index2word[index]
-            return [
-                self._gensim_keyed_vectors.index2word.index(i[0])
-                for i in self._gensim_keyed_vectors.similar_by_word(word, topn)
-            ]
-        else:
-            raise ValueError(
-                f"Not supported word embedding source {self.embedding_source}"
-            )
+        word = self.keyed_vectors.index2word[index]
+        return [
+            self.keyed_vectors.index2word.index(i[0])
+            for i in self.keyed_vectors.similar_by_word(word, topn)
+        ]
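Besides the prebuilt counterfitted_GLOVE_embedding(), TextAttackWordEmbedding can wrap any matrix plus word/index maps; a small sketch with made-up data (nn_matrix is optional, so nearest neighbours are computed on the fly here):

    import numpy as np
    from textattack.shared import TextAttackWordEmbedding

    matrix = np.array([[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]])
    word2index = {"hi": 0, "hello": 1, "bye": 2}
    index2word = {i: w for w, i in word2index.items()}

    embedding = TextAttackWordEmbedding(matrix, word2index, index2word)
    print(embedding["hello"])                    # row 1 of the matrix
    print(embedding.get_cos_sim("hi", "hello"))  # cached after the first call
    print(embedding.nearest_neighbours(0, 1))    # computed without a precomputed nn_matrix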
@@ -3,30 +3,32 @@ Word Swap by Embedding
 ============================================

 """
-from textattack.shared.word_embedding import WordEmbedding
+from textattack.shared import TextAttackWordEmbedding, WordEmbedding
 from textattack.transformations.word_swap import WordSwap


 class WordSwapEmbedding(WordSwap):
     """Transforms an input by replacing its words with synonyms in the word
-    embedding space."""
+    embedding space.

-    PATH = "word_embeddings"
+    Args:
+        max_candidates (int): maximum number of synonyms to pick
+        embedding (textattack.shared.WordEmbedding): Wrapper for word embedding
+    """

    def __init__(
        self,
        max_candidates=15,
-        embedding_type="paragramcf",
-        embedding_source=None,
+        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.max_candidates = max_candidates
-        self.embedding_type = embedding_type
-        self.embedding_source = embedding_source
-        self.embedding = WordEmbedding(
-            embedding_type=embedding_type, embedding_source=embedding_source
-        )
+        if not isinstance(embedding, WordEmbedding):
+            raise ValueError(
+                "`embedding` object must be of type `textattack.shared.WordEmbedding`."
+            )
+        self.embedding = embedding

    def _get_replacement_words(self, word):
        """Returns a list of possible 'candidate words' to replace a word in a
@@ -35,11 +37,11 @@ class WordSwapEmbedding(WordSwap):
        Based on nearest neighbors selected word embeddings.
        """
        try:
-            word_id = self.embedding.word2ind(word.lower())
-            nnids = self.embedding.nn(word_id, self.max_candidates)
+            word_id = self.embedding.word2index(word.lower())
+            nnids = self.embedding.nearest_neighbours(word_id, self.max_candidates)
            candidate_words = []
            for i, nbr_id in enumerate(nnids):
-                nbr_word = self.embedding.ind2word(nbr_id)
+                nbr_word = self.embedding.index2word(nbr_id)
                candidate_words.append(recover_word_case(nbr_word, word))
            return candidate_words
        except KeyError:
@@ -47,7 +49,7 @@ class WordSwapEmbedding(WordSwap):
            return []

    def extra_repr_keys(self):
-        return ["max_candidates", "embedding_type"]
+        return ["max_candidates", "embedding"]


def recover_word_case(word, reference_word):
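Finally, a usage sketch of the transformation under the new interface (the query word is arbitrary; candidates depend on the loaded embedding):

    from textattack.shared import TextAttackWordEmbedding
    from textattack.transformations import WordSwapEmbedding

    transformation = WordSwapEmbedding(
        max_candidates=5,
        embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
    )
    print(transformation._get_replacement_words("happy"))  # up to 5 nearest-neighbour swaps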