From e583188b9193bb4989bfa799a4029e2da3ae41c3 Mon Sep 17 00:00:00 2001
From: Jin Yong Yoo
Date: Fri, 13 Nov 2020 01:01:03 -0500
Subject: [PATCH] separate different types of word embeddings into different
 classes

---
 requirements.txt | 1 -
 setup.py | 1 +
 tests/sample_outputs/interactive_mode.txt | 4 +-
 tests/sample_outputs/kuleshov_cnn_sst_2.txt | 4 +-
 .../run_attack_faster_alzantot_recipe.txt | 4 +-
 ...run_attack_flair_pos_tagger_bert_score.txt | 2 +-
 .../run_attack_gradient_greedy_word_wir.txt | 2 +-
 .../run_attack_hotflip_lstm_mr_4.txt | 2 +-
 .../run_attack_stanza_pos_tagger.txt | 2 +-
 tests/test_word_embedding.py | 12 +-
 textattack/attack_recipes/kuleshov_2017.py | 6 +-
 textattack/augmentation/recipes.py | 4 +-
 .../sentence_encoders/thought_vector.py | 21 +-
 .../semantics/word_embedding_distance.py | 47 +-
 textattack/shared/__init__.py | 2 +-
 textattack/shared/utils/strings.py | 17 +-
 textattack/shared/word_embedding.py | 606 ++++++++++--------
 .../transformations/word_swap_embedding.py | 30 +-
 18 files changed, 435 insertions(+), 332 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 5dd094c8..027dfe79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,6 @@ bert-score>=0.3.5
editdistance
flair==0.6.1.post1
filelock
-gensim==3.8.3
language_tool_python
lemminflect
lru-dict
diff --git a/setup.py b/setup.py
index 07a5db3f..d85e964d 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@ extras["optional"] = [
"stanza",
"visdom",
"wandb",
+ "gensim==3.8.3",
]
# For developers, install development tools along with all optional dependencies.
diff --git a/tests/sample_outputs/interactive_mode.txt b/tests/sample_outputs/interactive_mode.txt
index d3f65f37..1f4457d1 100644
--- a/tests/sample_outputs/interactive_mode.txt
+++ b/tests/sample_outputs/interactive_mode.txt
@@ -5,11 +5,11 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 50
- (embedding_type): paragramcf
+ (embedding): TextAttackWordEmbedding
)
(constraints):
(0): WordEmbeddingDistance(
- (embedding_type): paragramcf
+ (embedding): TextAttackWordEmbedding
(min_cos_sim): 0.5
(cased): False
(include_unknown_words): True
diff --git a/tests/sample_outputs/kuleshov_cnn_sst_2.txt b/tests/sample_outputs/kuleshov_cnn_sst_2.txt
index a8dccbf3..d597ca7b 100644
--- a/tests/sample_outputs/kuleshov_cnn_sst_2.txt
+++ b/tests/sample_outputs/kuleshov_cnn_sst_2.txt
@@ -3,7 +3,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
- (embedding_type): paragramcf
+ (embedding): TextAttackWordEmbedding
)
(constraints):
(0): MaxWordsPerturbed(
@@ -11,7 +11,7 @@
(compare_against_original): True
)
(1): ThoughtVector(
- (embedding_type): paragramcf
+ (word_embedding): TextAttackWordEmbedding
(metric): max_euclidean
(threshold): -0.2
(window_size): inf
diff --git a/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt b/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt
index bbbe2deb..68010129 100644
--- a/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt
+++ b/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt
@@ -10,7 +10,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 8
- (embedding_type): paragramcf
+ (embedding): TextAttackWordEmbedding
)
(constraints):
(0): MaxWordsPerturbed(
@@ -18,7 +18,7 @@
(compare_against_original): True
)
(1): WordEmbeddingDistance(
- (embedding_type): paragramcf
+ (embedding): TextAttackWordEmbedding
(max_mse_dist): 0.5
(cased): False (include_unknown_words): True diff --git a/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt b/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt index b9bb683f..8e045372 100644 --- a/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt +++ b/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt @@ -5,7 +5,7 @@ (goal_function): UntargetedClassification (transformation): WordSwapEmbedding( (max_candidates): 15 - (embedding_type): paragramcf + (embedding): TextAttackWordEmbedding ) (constraints): (0): BERTScore( diff --git a/tests/sample_outputs/run_attack_gradient_greedy_word_wir.txt b/tests/sample_outputs/run_attack_gradient_greedy_word_wir.txt index d07013a1..24448a6f 100644 --- a/tests/sample_outputs/run_attack_gradient_greedy_word_wir.txt +++ b/tests/sample_outputs/run_attack_gradient_greedy_word_wir.txt @@ -5,7 +5,7 @@ (goal_function): UntargetedClassification (transformation): WordSwapEmbedding( (max_candidates): 15 - (embedding_type): paragramcf + (embedding): TextAttackWordEmbedding ) (constraints): (0): RepeatModification diff --git a/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt index 05fca604..9eb53cb0 100644 --- a/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt +++ b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt @@ -12,7 +12,7 @@ (compare_against_original): True ) (1): WordEmbeddingDistance( - (embedding_type): paragramcf + (embedding): TextAttackWordEmbedding (min_cos_sim): 0.8 (cased): False (include_unknown_words): True diff --git a/tests/sample_outputs/run_attack_stanza_pos_tagger.txt b/tests/sample_outputs/run_attack_stanza_pos_tagger.txt index 9347c3e5..d1988992 100644 --- a/tests/sample_outputs/run_attack_stanza_pos_tagger.txt +++ b/tests/sample_outputs/run_attack_stanza_pos_tagger.txt @@ -3,7 +3,7 @@ (goal_function): UntargetedClassification (transformation): WordSwapEmbedding( (max_candidates): 15 - (embedding_type): paragramcf + (embedding): TextAttackWordEmbedding ) (constraints): (0): PartOfSpeech( diff --git a/tests/test_word_embedding.py b/tests/test_word_embedding.py index f11e1020..537104c1 100644 --- a/tests/test_word_embedding.py +++ b/tests/test_word_embedding.py @@ -3,11 +3,11 @@ import os import numpy as np import pytest -from textattack.shared import WordEmbedding +from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding def test_embedding_paragramcf(): - word_embedding = WordEmbedding() + word_embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding() assert pytest.approx(word_embedding[0][0]) == -0.022007 assert pytest.approx(word_embedding["fawn"][0]) == -0.022007 assert word_embedding[10 ** 9] is None @@ -28,7 +28,7 @@ bye-bye -1 1 """ ) f.close() - word_embedding = WordEmbedding(embedding_type=path, embedding_source="gensim") + word_embedding = GensimWordEmbedding(path) assert pytest.approx(word_embedding[0][0]) == 1 assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2) assert word_embedding[10 ** 9] is None @@ -38,8 +38,8 @@ bye-bye -1 1 # mse dist assert pytest.approx(word_embedding.get_mse_dist(0, 2)) == 4 # nearest neighbour of hi is hello - assert word_embedding.nn(0, 1)[0] == 1 - assert word_embedding.word2ind("bye") == 2 - assert word_embedding.ind2word(3) == "bye-bye" + assert word_embedding.nearest_neighbours(0, 1)[0] == 1 + assert word_embedding.word2index("bye") == 2 + assert word_embedding.index2word(3) == "bye-bye" # remove test file os.remove(path) diff 
--git a/textattack/attack_recipes/kuleshov_2017.py b/textattack/attack_recipes/kuleshov_2017.py index b1eb1017..49bd4933 100644 --- a/textattack/attack_recipes/kuleshov_2017.py +++ b/textattack/attack_recipes/kuleshov_2017.py @@ -49,11 +49,7 @@ class Kuleshov2017(AttackRecipe): # # Maximum thought vector Euclidean distance of λ_1 = 0.2. (eq. 4) # - constraints.append( - ThoughtVector( - embedding_type="paragramcf", threshold=0.2, metric="max_euclidean" - ) - ) + constraints.append(ThoughtVector(threshold=0.2, metric="max_euclidean")) # # # Maximum language model log-probability difference of λ_2 = 2. (eq. 5) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 92014d1c..c4d98bf8 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -115,9 +115,7 @@ class EmbeddingAugmenter(Augmenter): def __init__(self, **kwargs): from textattack.transformations import WordSwapEmbedding - transformation = WordSwapEmbedding( - max_candidates=50, embedding_type="paragramcf" - ) + transformation = WordSwapEmbedding(max_candidates=50) from textattack.constraints.semantics import WordEmbeddingDistance constraints = DEFAULT_CONSTRAINTS + [WordEmbeddingDistance(min_cos_sim=0.8)] diff --git a/textattack/constraints/semantics/sentence_encoders/thought_vector.py b/textattack/constraints/semantics/sentence_encoders/thought_vector.py index 65872ec8..5d5ef063 100644 --- a/textattack/constraints/semantics/sentence_encoders/thought_vector.py +++ b/textattack/constraints/semantics/sentence_encoders/thought_vector.py @@ -7,7 +7,7 @@ import functools import torch -from textattack.shared import WordEmbedding, utils +from textattack.shared import TextAttackWordEmbedding, WordEmbedding, utils from .sentence_encoder import SentenceEncoder @@ -16,14 +16,19 @@ class ThoughtVector(SentenceEncoder): """A constraint on the distance between two sentences' thought vectors. Args: - word_embedding (str): The word embedding to use - min_cos_sim: the minimum cosine similarity between thought vectors - max_mse_dist: the maximum euclidean distance between thought vectors + word_embedding (textattack.shared.WordEmbedding): The word embedding to use """ - def __init__(self, embedding_type="paragramcf", embedding_source=None, **kwargs): - self.word_embedding = WordEmbedding(embedding_type, embedding_source) - self.embedding_type = embedding_type + def __init__( + self, + embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(), + **kwargs + ): + if not isinstance(embedding, WordEmbedding): + raise ValueError( + "`embedding` object must be of type `textattack.shared.WordEmbedding`." 
+ ) + self.word_embedding = embedding super().__init__(**kwargs) def clear_cache(self): @@ -46,4 +51,4 @@ class ThoughtVector(SentenceEncoder): def extra_repr_keys(self): """Set the extra representation of the constraint using these keys.""" - return ["embedding_type"] + super().extra_repr_keys() + return ["word_embedding"] + super().extra_repr_keys() diff --git a/textattack/constraints/semantics/word_embedding_distance.py b/textattack/constraints/semantics/word_embedding_distance.py index 064d4299..f300e4f5 100644 --- a/textattack/constraints/semantics/word_embedding_distance.py +++ b/textattack/constraints/semantics/word_embedding_distance.py @@ -1,10 +1,11 @@ """ Word Embedding Distance -------------------------- +-------------------------- """ + from textattack.constraints import Constraint +from textattack.shared import TextAttackWordEmbedding, WordEmbedding from textattack.shared.validators import transformation_consists_of_word_swaps -from textattack.shared.word_embedding import WordEmbedding class WordEmbeddingDistance(Constraint): @@ -13,24 +14,17 @@ class WordEmbeddingDistance(Constraint): inserted. Args: - embedding_type (str): The word embedding to use. - embedding_source (str): The source of embeddings ("defaults" or "gensim") - include_unknown_words (bool): Whether or not the constraint is fulfilled - if the embedding of x or x_adv is unknown. - min_cos_sim: The minimum cosine similarity between word embeddings. - max_mse_dist: The maximum euclidean distance between word embeddings. - cased (bool): Whether embedding supports uppercase & lowercase - (defaults to False, or just lowercase). - compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. - Otherwise, compare it against the previous `x_adv`. + embedding (obj): Wrapper for word embedding. + include_unknown_words (bool): Whether or not the constraint is fulfilled if the embedding of x or x_adv is unknown. + min_cos_sim (:obj:`float`, optional): The minimum cosine similarity between word embeddings. + max_mse_dist (:obj:`float`, optional): The maximum euclidean distance between word embeddings. + cased (bool): Whether embedding supports uppercase & lowercase (defaults to False, or just lowercase). + compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. Otherwise, compare it against the previous `x_adv`. """ - PATH = "word_embeddings" - def __init__( self, - embedding_type="paragramcf", - embedding_source=None, + embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(), include_unknown_words=True, min_cos_sim=None, max_mse_dist=None, @@ -40,14 +34,17 @@ class WordEmbeddingDistance(Constraint): super().__init__(compare_against_original) self.include_unknown_words = include_unknown_words self.cased = cased + + if bool(min_cos_sim) == bool(max_mse_dist): + raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.") self.min_cos_sim = min_cos_sim self.max_mse_dist = max_mse_dist - self.embedding_type = embedding_type - self.embedding_source = embedding_source - self.embedding = WordEmbedding( - embedding_type=embedding_type, embedding_source=embedding_source - ) + if not isinstance(embedding, WordEmbedding): + raise ValueError( + "`embedding` object must be of type `textattack.shared.WordEmbedding`." 
+ )
+ self.embedding = embedding
def get_cos_sim(self, a, b):
"""Returns the cosine similarity of words with IDs a and b."""
@@ -59,7 +56,7 @@ class WordEmbeddingDistance(Constraint):
def _check_constraint(self, transformed_text, reference_text):
"""Returns true if (``transformed_text`` and ``reference_text``) are
- closer than ``self.min_cos_sim`` and ``self.max_mse_dist``."""
+ closer than ``self.min_cos_sim`` or ``self.max_mse_dist``."""
try:
indices = transformed_text.attack_attrs["newly_modified_indices"]
except KeyError:
@@ -77,8 +74,8 @@ class WordEmbeddingDistance(Constraint):
transformed_word = transformed_word.lower()
try:
- ref_id = self.embedding.word2ind(ref_word)
- transformed_id = self.embedding.word2ind(transformed_word)
+ ref_id = self.embedding.word2index(ref_word)
+ transformed_id = self.embedding.word2index(transformed_word)
except KeyError:
# This error is thrown if x or x_adv has no corresponding ID.
if self.include_unknown_words:
@@ -116,7 +113,7 @@ class WordEmbeddingDistance(Constraint):
else:
metric = "min_cos_sim"
return [
- "embedding_type",
+ "embedding",
metric,
"cased",
"include_unknown_words",
diff --git a/textattack/shared/__init__.py b/textattack/shared/__init__.py
index 89dbe67e..71540a7d 100644
--- a/textattack/shared/__init__.py
+++ b/textattack/shared/__init__.py
@@ -13,6 +13,6 @@ from .utils import logger
from . import validators
from .attacked_text import AttackedText
-from .word_embedding import WordEmbedding
+from .word_embedding import *
from .attack import Attack
from .checkpoint import Checkpoint
diff --git a/textattack/shared/utils/strings.py b/textattack/shared/utils/strings.py
index 4c5fd1fc..b82271be 100644
--- a/textattack/shared/utils/strings.py
+++ b/textattack/shared/utils/strings.py
@@ -49,15 +49,18 @@ def words_from_text(s, words_to_ignore=[]):
def default_class_repr(self):
- extra_params = []
- for key in self.extra_repr_keys():
- extra_params.append(" (" + key + ")" + ": {" + key + "}")
- if len(extra_params):
- extra_str = "\n" + "\n".join(extra_params) + "\n"
- extra_str = f"({extra_str})"
+ if hasattr(self, "extra_repr_keys"):
+ extra_params = []
+ for key in self.extra_repr_keys():
+ extra_params.append(" (" + key + ")" + ": {" + key + "}")
+ if len(extra_params):
+ extra_str = "\n" + "\n".join(extra_params) + "\n"
+ extra_str = f"({extra_str})"
+ else:
+ extra_str = ""
+ extra_str = extra_str.format(**self.__dict__)
else:
extra_str = ""
- extra_str = extra_str.format(**self.__dict__)
return f"{self.__class__.__name__}{extra_str}"
diff --git a/textattack/shared/word_embedding.py b/textattack/shared/word_embedding.py
index 9a2ad561..59801717 100644
--- a/textattack/shared/word_embedding.py
+++ b/textattack/shared/word_embedding.py
@@ -3,6 +3,7 @@ Shared loads word embeddings and related distances
=====================================================
"""
+from abc import ABC, abstractmethod
from collections import defaultdict
import os
import pickle
@@ -10,77 +11,260 @@ import pickle
import numpy as np
import torch
-import textattack
from textattack.shared import utils
-class WordEmbedding:
- """An object that loads word embeddings and related distances.
+class WordEmbedding(ABC):
+ """Abstract class representing word embedding used by TextAttack.
+
+ This class specifies all the methods that are required to be defined
+ so that it can be used for transformations and constraints. For
+ a custom word embedding not supported by TextAttack, please create a
+ class that inherits from this class and implements the required methods.
+ However, please first check whether you can use the TextAttackWordEmbedding
+ class, which already implements most of the required methods.
+ """
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Gets the embedding vector for word/id
+ Args:
+ index (Union[str|int]): `index` can either be word or integer representing the id of the word.
+ Returns:
+ vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_mse_dist(self, a, b):
+ """Return MSE (i.e. L2-norm) distance between vector for word `a` and
+ vector for word `b`.
+
+ Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value.
+ Args:
+ a (Union[str|int]): Either word or integer representing the id of the word
+ b (Union[str|int]): Either word or integer representing the id of the word
+ Returns:
+ distance (float): MSE (L2) distance
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_cos_sim(self, a, b):
+ """Return cosine similarity between vector for word `a` and vector for
+ word `b`.
+
+ Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
+ Args:
+ a (Union[str|int]): Either word or integer representing the id of the word
+ b (Union[str|int]): Either word or integer representing the id of the word
+ Returns:
+ distance (float): cosine similarity
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def word2index(self, word):
+ """
+ Convert word to id (i.e. index of the word in the embedding matrix)
+ Args:
+ word (str)
+ Returns:
+ index (int)
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def index2word(self, index):
+ """
+ Convert index to corresponding word
+ Args:
+ index (int)
+ Returns:
+ word (str)
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def nearest_neighbours(self, index, topn):
+ """
+ Get top-N nearest neighbours for a word
+ Args:
+ index (int): ID of the word for which we're finding the nearest neighbours
+ topn (int): Used for specifying N nearest neighbours
+ Returns:
+ neighbours (list[int]): List of indices of the nearest neighbours
+ """
+ raise NotImplementedError()
+
+ __repr__ = __str__ = utils.default_class_repr
+
+
+class TextAttackWordEmbedding(WordEmbedding):
+ """An object that loads word embeddings and related distances for
+ TextAttack. This has a lot of internal components (e.g. cosine
+ similarity) implemented. Consider using this class if you can provide the
+ appropriate input data to create the object.
Args:
- embedding_type (str): The type of the embedding to load automatically
- embedding_source (str): Source of embeddings provided,
- "defaults" corresponds to the textattack s3 bucket
- "gensim" expects a word2vec model
+ embedding_matrix (ndarray): 2-D array of NxD where N represents size of vocab and D is the dimension of embedding vectors.
+ word2index (Union[dict|object]): dictionary (or a similar object) that maps word to its index within the embedding matrix.
+ index2word (Union[dict|object]): dictionary (or a similar object) that maps index to its word.
+ nn_matrix (ndarray): Precomputed nearest neighbours matrix. It should be a 2-D integer array of NxK
+ where N represents size of vocab and K is the top-K nearest neighbours. If this is not defined, we have to compute nearest neighbours
+ on the fly for the `nearest_neighbours` method, which is costly.
""" PATH = "word_embeddings" - EMBEDDINGS_AVAILABLE_IN_TEXTATTACK = {"paragramcf"} - def __init__(self, embedding_type=None, embedding_source=None): + def __init__(self, embedding_matrix, word2index, index2word, nn_matrix=None): + self.embedding_matrix = embedding_matrix + self._word2index = word2index + self._index2word = index2word + self.nn_matrix = nn_matrix - self.embedding_type = embedding_type or "paragramcf" - self.embedding_source = embedding_source or "defaults" + # Dictionary for caching results + self._mse_dist_mat = defaultdict(dict) + self._cos_sim_mat = defaultdict(dict) + self._nn_cache = {} - if self.embedding_source == "defaults": - if ( - self.embedding_type - not in WordEmbedding.EMBEDDINGS_AVAILABLE_IN_TEXTATTACK - ): - raise ValueError( - f"{self.embedding_type} is not available in TextAttack." - ) - elif self.embedding_source not in ["gensim"]: - raise ValueError(f"{self.embedding_source} type is not supported.") - - self._embeddings = None - self._word2index = None - self._index2word = None - self._cos_sim_mat = None - self._mse_dist_mat = None - self._nn = None - - self._gensim_keyed_vectors = None - - self._init_embeddings_from_type(self.embedding_source, self.embedding_type) - - def _init_embeddings_from_defaults(self, embedding_type): - """ - Init embeddings prepared in the textattack s3 bucket + def __getitem__(self, index): + """Gets the embedding vector for word/id Args: - embedding_type: - + index (Union[str|int]): `index` can either be word or integer representing the id of the word. Returns: + vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`. + """ + if isinstance(index, str): + try: + index = self._word2index[index] + except KeyError: + return None + try: + return self.embedding_matrix[index] + except IndexError: + # word embedding ID out of bounds + return None + + def word2index(self, word): + """ + Convert between word to id (i.e. index of word in embedding matrix) + Args: + word (str) + Returns: + index (int) + """ + return self._word2index[word] + + def index2word(self, index): + """ + Convert index to corresponding word + Args: + index (int) + Returns: + word (str) """ - if embedding_type == "paragramcf": - word_embeddings_folder = "paragramcf" - word_embeddings_file = "paragram.npy" - word_list_file = "wordlist.pickle" - mse_dist_file = "mse_dist.p" - cos_sim_file = "cos_sim.p" - nn_matrix_file = "nn.npy" + return self._index2word[index] + + def get_mse_dist(self, a, b): + """Return MSE (i.e. L2-norm) distance between vector for word `a` and + vector for word `b`. + + Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value. + Args: + a (Union[str|int]): Either word or integer presenting the id of the word + b (Union[str|int]): Either word or integer presenting the id of the word + Returns: + distance (float): MSE (L2) distance + """ + if isinstance(a, str): + a = self._word2index[a] + if isinstance(b, str): + b = self._word2index[b] + a, b = min(a, b), max(a, b) + try: + mse_dist = self._mse_dist_mat[a][b] + except KeyError: + e1 = self.embedding_matrix[a] + e2 = self.embedding_matrix[b] + e1 = torch.tensor(e1).to(utils.device) + e2 = torch.tensor(e2).to(utils.device) + mse_dist = torch.sum((e1 - e2) ** 2).item() + self._mse_dist_mat[a][b] = mse_dist + + return mse_dist + + def get_cos_sim(self, a, b): + """Return cosine similarity between vector for word `a` and vector for + word `b`. 
+
+ Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
+ Args:
+ a (Union[str|int]): Either word or integer representing the id of the word
+ b (Union[str|int]): Either word or integer representing the id of the word
+ Returns:
+ distance (float): cosine similarity
+ """
+ if isinstance(a, str):
+ a = self._word2index[a]
+ if isinstance(b, str):
+ b = self._word2index[b]
+ a, b = min(a, b), max(a, b)
+ try:
+ cos_sim = self._cos_sim_mat[a][b]
+ except KeyError:
+ e1 = self.embedding_matrix[a]
+ e2 = self.embedding_matrix[b]
+ e1 = torch.tensor(e1).to(utils.device)
+ e2 = torch.tensor(e2).to(utils.device)
+ cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item()
+ self._cos_sim_mat[a][b] = cos_sim
+ return cos_sim
+
+ def nearest_neighbours(self, index, topn):
+ """
+ Get top-N nearest neighbours for a word
+ Args:
+ index (int): ID of the word for which we're finding the nearest neighbours
+ topn (int): Used for specifying N nearest neighbours
+ Returns:
+ neighbours (list[int]): List of indices of the nearest neighbours
+ """
+ if isinstance(index, str):
+ index = self._word2index[index]
+ if self.nn_matrix is not None:
+ nn = self.nn_matrix[index][1 : (topn + 1)]
else:
- raise ValueError(f"Could not find word embedding {embedding_type}")
+ try:
+ nn = self._nn_cache[index]
+ except KeyError:
+ embedding = torch.tensor(self.embedding_matrix).to(utils.device)
+ vector = torch.tensor(self.embedding_matrix[index]).to(utils.device)
+ dist = torch.norm(embedding - vector, dim=1, p=None)
+ # Since closest neighbour will be the same word, we consider N+1 nearest neighbours
+ nn = dist.topk(topn + 1, largest=False)[1:].tolist()
+ self._nn_cache[index] = nn
+
+ return nn
+
+ @staticmethod
+ def counterfitted_GLOVE_embedding():
+ """Returns a prebuilt counter-fitted GloVe word embedding proposed by
+ "Counter-fitting Word Vectors to Linguistic Constraints" (Mrkšić et
+ al., 2016)"""
+ word_embeddings_folder = "paragramcf"
+ word_embeddings_file = "paragram.npy"
+ word_list_file = "wordlist.pickle"
+ mse_dist_file = "mse_dist.p"
+ cos_sim_file = "cos_sim.p"
+ nn_matrix_file = "nn.npy"
# Download embeddings if they're not cached.
word_embeddings_folder = os.path.join(
- WordEmbedding.PATH, word_embeddings_folder
+ TextAttackWordEmbedding.PATH, word_embeddings_folder
)
- word_embeddings_folder = textattack.shared.utils.download_if_needed(
- word_embeddings_folder
- )
-
+ word_embeddings_folder = utils.download_if_needed(word_embeddings_folder)
# Concatenate folder names to create full path to files.
word_embeddings_file = os.path.join( word_embeddings_folder, word_embeddings_file @@ -91,234 +275,152 @@ class WordEmbedding: nn_matrix_file = os.path.join(word_embeddings_folder, nn_matrix_file) # loading the files - self._embeddings = np.load(word_embeddings_file) - self._word2index = np.load(word_list_file, allow_pickle=True) - if os.path.exists(mse_dist_file): - with open(mse_dist_file, "rb") as f: - self._mse_dist_mat = pickle.load(f) - else: - self._mse_dist_mat = defaultdict(dict) - if os.path.exists(cos_sim_file): - with open(cos_sim_file, "rb") as f: - self._cos_sim_mat = pickle.load(f) - else: - self._cos_sim_mat = defaultdict(dict) + embedding_matrix = np.load(word_embeddings_file) + word2index = np.load(word_list_file, allow_pickle=True) + index2word = {} + for word, index in word2index.items(): + index2word[index] = word + nn_matrix = np.load(nn_matrix_file) - self._nn = np.load(nn_matrix_file) + word_embedding = TextAttackWordEmbedding( + embedding_matrix, word2index, index2word, nn_matrix + ) + with open(mse_dist_file, "rb") as f: + mse_dist_mat = pickle.load(f) + with open(cos_sim_file, "rb") as f: + cos_sim_mat = pickle.load(f) - # Build glove dict and index. - self._index2word = dict() - for word, index in self._word2index.items(): - self._index2word[index] = word + word_embedding._mse_dist_mat = mse_dist_mat + word_embedding._cos_sim_mat = cos_sim_mat - def _init_embeddings_from_gensim(self, embedding_type): - """ - Initialize word embedding from a gensim word2vec model - Args: - embedding_type: + return word_embedding - Returns: - """ - import gensim +class GensimWordEmbedding(WordEmbedding): + """Wraps Gensim's KeyedVectors + (https://radimrehurek.com/gensim/models/keyedvectors.html)""" - if embedding_type.endswith(".bin"): - self._gensim_keyed_vectors = ( - gensim.models.KeyedVectors.load_word2vec_format( - embedding_type, binary=True + def __init__(self, keyed_vectors_or_path): + gensim = utils.LazyLoader("gensim", globals(), "gensim") + + if isinstance(keyed_vectors_or_path, str): + if keyed_vectors_or_path.endswith(".bin"): + self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format( + keyed_vectors_or_path, binary=True ) - ) + else: + self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format( + keyed_vectors_or_path + ) + elif isinstance(keyed_vectors_or_path, gensim.models.KeyedVectors): + self.keyed_vectors = keyed_vectors_or_path else: - self._gensim_keyed_vectors = ( - gensim.models.KeyedVectors.load_word2vec_format(embedding_type) + raise ValueError( + "`keyed_vectors_or_path` argument must either be `gensim.models.KeyedVectors` object " + "or a path pointing to the saved KeyedVector object" ) - self._gensim_keyed_vectors.init_sims() - def _init_embeddings_from_type(self, embedding_source, embedding_type): - """Initializes embedding based on the source. - - Downloads and loads embeddings into memory. - """ - - if embedding_source == "defaults": - self._init_embeddings_from_defaults(embedding_type) - elif embedding_source == "gensim": - self._init_embeddings_from_gensim(embedding_type) - else: - raise ValueError(f"Not supported word embedding source {embedding_source}") + self.keyed_vectors.init_sims() + self._mse_dist_mat = defaultdict(dict) + self._cos_sim_mat = defaultdict(dict) def __getitem__(self, index): - """Gets a word embedding by word or ID. - - If word or ID not found, returns None. 
- """ - if self.embedding_source == "defaults": - if isinstance(index, str): - try: - index = self._word2index[index] - except KeyError: - return None - try: - return self._embeddings[index] - except IndexError: - # word embedding ID out of bounds - return None - elif self.embedding_source == "gensim": - if isinstance(index, str): - try: - index = self._gensim_keyed_vectors.vocab.get(index).index - except KeyError: - return None - try: - return self._gensim_keyed_vectors.vectors_norm[index] - except IndexError: - # word embedding ID out of bounds - return None - else: - raise ValueError( - f"Not supported word embedding source {self.embedding_source}" - ) - - def get_cos_sim(self, a, b): - """ - get cosine similarity of two words/IDs + """Gets the embedding vector for word/id Args: - a: - b: - + index (Union[str|int]): `index` can either be word or integer representing the id of the word. Returns: + vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`. + """ + if isinstance(index, str): + try: + index = self.keyed_vectors.vocab.get(index).index + except KeyError: + return None + try: + return self.keyed_vectors.vectors_norm[index] + except IndexError: + # word embedding ID out of bounds + return None + + def word2index(self, word): + """ + Convert between word to id (i.e. index of word in embedding matrix) + Args: + word (str) + Returns: + index (int) + """ + vocab = self.keyed_vectors.vocab.get(word) + if vocab is None: + raise KeyError(word) + return vocab.index + + def index2word(self, index): + """ + Convert index to corresponding word + Args: + index (int) + Returns: + word (str) """ - if self.embedding_source == "defaults": - if isinstance(a, str): - a = self._word2index[a] - if isinstance(b, str): - b = self._word2index[b] - a, b = min(a, b), max(a, b) - try: - cos_sim = self._cos_sim_mat[a][b] - except KeyError: - e1 = self._embeddings[a] - e2 = self._embeddings[b] - e1 = torch.tensor(e1).to(utils.device) - e2 = torch.tensor(e2).to(utils.device) - cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item() - self._cos_sim_mat[a][b] = cos_sim - return cos_sim - elif self.embedding_source == "gensim": - if not isinstance(a, str): - a = self._gensim_keyed_vectors.index2word[a] - if not isinstance(b, str): - b = self._gensim_keyed_vectors.index2word[b] - cos_sim = self._gensim_keyed_vectors.similarity(a, b) - return cos_sim - else: - raise ValueError( - f"Not supported word embedding source {self.embedding_source}" - ) + try: + # this is a list, so the error would be IndexError + return self.keyed_vectors.index2word[index] + except IndexError: + raise KeyError(index) def get_mse_dist(self, a, b): - """ - get mse distance of two IDs + """Return MSE (i.e. L2-norm) distance between vector for word `a` and + vector for word `b`. + + Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value. 
Args:
- a:
- b:
-
+ a (Union[str|int]): Either word or integer representing the id of the word
+ b (Union[str|int]): Either word or integer representing the id of the word
Returns:
-
+ distance (float): MSE (L2) distance
"""
- if self.embedding_source == "defaults":
- a, b = min(a, b), max(a, b)
- try:
- mse_dist = self._mse_dist_mat[a][b]
- except KeyError:
- e1 = self._embeddings[a]
- e2 = self._embeddings[b]
- e1 = torch.tensor(e1).to(utils.device)
- e2 = torch.tensor(e2).to(utils.device)
- mse_dist = torch.sum((e1 - e2) ** 2).item()
- self._mse_dist_mat[a][b] = mse_dist
- return mse_dist
- elif self.embedding_source == "gensim":
- if self._mse_dist_mat is None:
- self._mse_dist_mat = defaultdict(dict)
- try:
- mse_dist = self._mse_dist_mat[a][b]
- except KeyError:
- e1 = self._gensim_keyed_vectors.vectors_norm[a]
- e2 = self._gensim_keyed_vectors.vectors_norm[b]
- e1 = torch.tensor(e1).to(utils.device)
- e2 = torch.tensor(e2).to(utils.device)
- mse_dist = torch.sum((e1 - e2) ** 2).item()
- self._mse_dist_mat[a][b] = mse_dist
- return mse_dist
- else:
- raise ValueError(
- f"Not supported word embedding source {self.embedding_source}"
- )
+ try:
+ mse_dist = self._mse_dist_mat[a][b]
+ except KeyError:
+ e1 = self.keyed_vectors.vectors_norm[a]
+ e2 = self.keyed_vectors.vectors_norm[b]
+ e1 = torch.tensor(e1).to(utils.device)
+ e2 = torch.tensor(e2).to(utils.device)
+ mse_dist = torch.sum((e1 - e2) ** 2).item()
+ self._mse_dist_mat[a][b] = mse_dist
+ return mse_dist
- def word2ind(self, word):
- """
- word to index
+ def get_cos_sim(self, a, b):
+ """Return cosine similarity between vector for word `a` and vector for
+ word `b`.
+
+ Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
Args:
- word:
-
+ a (Union[str|int]): Either word or integer representing the id of the word
+ b (Union[str|int]): Either word or integer representing the id of the word
Returns:
-
+ distance (float): cosine similarity
"""
- if self.embedding_source == "defaults":
- return self._word2index[word]
- elif self.embedding_source == "gensim":
- vocab = self._gensim_keyed_vectors.vocab.get(word)
- if vocab is None:
- raise KeyError(word)
- return vocab.index
- else:
- raise ValueError(
- f"Not supported word embedding source {self.embedding_source}"
- )
+ if not isinstance(a, str):
+ a = self.keyed_vectors.index2word[a]
+ if not isinstance(b, str):
+ b = self.keyed_vectors.index2word[b]
+ cos_sim = self.keyed_vectors.similarity(a, b)
+ return cos_sim
- def ind2word(self, index):
+ def nearest_neighbours(self, index, topn, return_words=True):
"""
- index to word
+ Get top-N nearest neighbours for a word
Args:
- index:
-
+ index (int): ID of the word for which we're finding the nearest neighbours
+ topn (int): Used for specifying N nearest neighbours
Returns:
-
+ neighbours (list[int]): List of indices of the nearest neighbours
"""
- if self.embedding_source == "defaults":
- return self._index2word[index]
- elif self.embedding_source == "gensim":
- try:
- # this is a list, so the error would be IndexError
- return self._gensim_keyed_vectors.index2word[index]
- except IndexError:
- raise KeyError(index)
- else:
- raise ValueError(
- f"Not supported word embedding source {self.embedding_source}"
- )
-
- def nn(self, index, topn):
- """
- get top n nearest neighbours for a word
- Args:
- index:
- topn:
-
- Returns:
-
- """
- if self.embedding_source == "defaults":
- return self._nn[index][1 : (topn + 1)]
- elif self.embedding_source == "gensim":
- word =
self._gensim_keyed_vectors.index2word[index] - return [ - self._gensim_keyed_vectors.index2word.index(i[0]) - for i in self._gensim_keyed_vectors.similar_by_word(word, topn) - ] - else: - raise ValueError( - f"Not supported word embedding source {self.embedding_source}" - ) + word = self.keyed_vectors.index2word[index] + return [ + self.keyed_vectors.index2word.index(i[0]) + for i in self.keyed_vectors.similar_by_word(word, topn) + ] diff --git a/textattack/transformations/word_swap_embedding.py b/textattack/transformations/word_swap_embedding.py index 3d268aa2..55d2f003 100644 --- a/textattack/transformations/word_swap_embedding.py +++ b/textattack/transformations/word_swap_embedding.py @@ -3,30 +3,32 @@ Word Swap by Embedding ============================================ """ -from textattack.shared.word_embedding import WordEmbedding +from textattack.shared import TextAttackWordEmbedding, WordEmbedding from textattack.transformations.word_swap import WordSwap class WordSwapEmbedding(WordSwap): """Transforms an input by replacing its words with synonyms in the word - embedding space.""" + embedding space. - PATH = "word_embeddings" + Args: + max_candidates (int): maximum number of synonyms to pick + embedding (textattack.shared.WordEmbedding): Wrapper for word embedding + """ def __init__( self, max_candidates=15, - embedding_type="paragramcf", - embedding_source=None, + embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(), **kwargs ): super().__init__(**kwargs) self.max_candidates = max_candidates - self.embedding_type = embedding_type - self.embedding_source = embedding_source - self.embedding = WordEmbedding( - embedding_type=embedding_type, embedding_source=embedding_source - ) + if not isinstance(embedding, WordEmbedding): + raise ValueError( + "`embedding` object must be of type `textattack.shared.WordEmbedding`." + ) + self.embedding = embedding def _get_replacement_words(self, word): """Returns a list of possible 'candidate words' to replace a word in a @@ -35,11 +37,11 @@ class WordSwapEmbedding(WordSwap): Based on nearest neighbors selected word embeddings. """ try: - word_id = self.embedding.word2ind(word.lower()) - nnids = self.embedding.nn(word_id, self.max_candidates) + word_id = self.embedding.word2index(word.lower()) + nnids = self.embedding.nearest_neighbours(word_id, self.max_candidates) candidate_words = [] for i, nbr_id in enumerate(nnids): - nbr_word = self.embedding.ind2word(nbr_id) + nbr_word = self.embedding.index2word(nbr_id) candidate_words.append(recover_word_case(nbr_word, word)) return candidate_words except KeyError: @@ -47,7 +49,7 @@ class WordSwapEmbedding(WordSwap): return [] def extra_repr_keys(self): - return ["max_candidates", "embedding_type"] + return ["max_candidates", "embedding"] def recover_word_case(word, reference_word):
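
As a usage sketch (not part of the patch itself), the refactored embedding API above could be exercised as follows. All class and method names are taken from the diff; the word2vec path and the probe word "hello" are placeholders rather than files or vocabulary shipped with TextAttack.

# Usage sketch for the classes introduced/changed in this patch.
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding
from textattack.transformations import WordSwapEmbedding

# Prebuilt counter-fitted GloVe vectors downloaded by TextAttack.
embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding()
transformation = WordSwapEmbedding(max_candidates=50, embedding=embedding)
# Exactly one of `min_cos_sim` / `max_mse_dist` must be given (see the new check above).
constraint = WordEmbeddingDistance(embedding=embedding, min_cos_sim=0.8)

# Any word2vec-format model can be wrapped through gensim's KeyedVectors.
custom = GensimWordEmbedding("path/to/word2vec.txt")  # placeholder path
word_id = custom.word2index("hello")                  # assumes "hello" is in the vocabulary
print(custom.nearest_neighbours(word_id, topn=5))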