mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Separate different types of word embeddings into different classes

This commit is contained in:
Jin Yong Yoo
2020-11-13 01:01:03 -05:00
parent 7caab13a88
commit e583188b91
18 changed files with 435 additions and 332 deletions

View File

@@ -2,7 +2,6 @@ bert-score>=0.3.5
editdistance
flair==0.6.1.post1
filelock
gensim==3.8.3
language_tool_python
lemminflect
lru-dict

View File

@@ -31,6 +31,7 @@ extras["optional"] = [
"stanza",
"visdom",
"wandb",
"gensim==3.8.3",
]
# For developers, install development tools along with all optional dependencies.

View File

@@ -5,11 +5,11 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 50
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): WordEmbeddingDistance(
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
(min_cos_sim): 0.5
(cased): False
(include_unknown_words): True

View File

@@ -3,7 +3,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): MaxWordsPerturbed(
@@ -11,7 +11,7 @@
(compare_against_original): True
)
(1): ThoughtVector(
(embedding_type): paragramcf
(word_embedding): TextAttackWordEmbedding
(metric): max_euclidean
(threshold): -0.2
(window_size): inf

View File

@@ -10,7 +10,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 8
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): MaxWordsPerturbed(
@@ -18,7 +18,7 @@
(compare_against_original): True
)
(1): WordEmbeddingDistance(
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
(max_mse_dist): 0.5
(cased): False
(include_unknown_words): True

View File

@@ -5,7 +5,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): BERTScore(

View File

@@ -5,7 +5,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): RepeatModification

View File

@@ -12,7 +12,7 @@
(compare_against_original): True
)
(1): WordEmbeddingDistance(
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
(min_cos_sim): 0.8
(cased): False
(include_unknown_words): True

View File

@@ -3,7 +3,7 @@
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
(embedding): TextAttackWordEmbedding
)
(constraints):
(0): PartOfSpeech(

View File

@@ -3,11 +3,11 @@ import os
import numpy as np
import pytest
from textattack.shared import WordEmbedding
from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding
def test_embedding_paragramcf():
word_embedding = WordEmbedding()
word_embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding()
assert pytest.approx(word_embedding[0][0]) == -0.022007
assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
assert word_embedding[10 ** 9] is None
@@ -28,7 +28,7 @@ bye-bye -1 1
"""
)
f.close()
word_embedding = WordEmbedding(embedding_type=path, embedding_source="gensim")
word_embedding = GensimWordEmbedding(path)
assert pytest.approx(word_embedding[0][0]) == 1
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
assert word_embedding[10 ** 9] is None
@@ -38,8 +38,8 @@ bye-bye -1 1
# mse dist
assert pytest.approx(word_embedding.get_mse_dist(0, 2)) == 4
# nearest neighbour of hi is hello
assert word_embedding.nn(0, 1)[0] == 1
assert word_embedding.word2ind("bye") == 2
assert word_embedding.ind2word(3) == "bye-bye"
assert word_embedding.nearest_neighbours(0, 1)[0] == 1
assert word_embedding.word2index("bye") == 2
assert word_embedding.index2word(3) == "bye-bye"
# remove test file
os.remove(path)
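
The updated test exercises the new embedding API directly. For reference, a minimal usage sketch of the same calls (the word2vec path is hypothetical):

from textattack.shared import GensimWordEmbedding, TextAttackWordEmbedding

# Prebuilt counter-fitted GloVe vectors shipped with TextAttack.
embedding = TextAttackWordEmbedding.counterfitted_GLOVE_embedding()
vector = embedding["fawn"]                         # 1-D vector, or None if the word is unknown
idx = embedding.word2index("fawn")                 # word -> index
word = embedding.index2word(idx)                   # index -> word
neighbours = embedding.nearest_neighbours(idx, 5)  # indices of the 5 nearest words

# Any word2vec file readable by gensim (text or .bin) can be wrapped the same way.
w2v = GensimWordEmbedding("/path/to/word2vec.txt")  # hypothetical path
print(w2v.get_cos_sim("hi", "hello"))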

View File

@@ -49,11 +49,7 @@ class Kuleshov2017(AttackRecipe):
#
# Maximum thought vector Euclidean distance of λ_1 = 0.2. (eq. 4)
#
constraints.append(
ThoughtVector(
embedding_type="paragramcf", threshold=0.2, metric="max_euclidean"
)
)
constraints.append(ThoughtVector(threshold=0.2, metric="max_euclidean"))
#
#
# Maximum language model log-probability difference of λ_2 = 2. (eq. 5)
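
With the new default argument, the recipe no longer passes an embedding_type string; a recipe that needs a different embedding can pass a WordEmbedding object instead. A sketch, assuming ThoughtVector is importable from textattack.constraints.semantics.sentence_encoders and using a hypothetical word2vec path:

from textattack.constraints.semantics.sentence_encoders import ThoughtVector
from textattack.shared import GensimWordEmbedding

# Default: the counter-fitted GloVe embedding is constructed implicitly.
constraint = ThoughtVector(threshold=0.2, metric="max_euclidean")

# Alternative: back the thought vectors with any WordEmbedding subclass.
custom_constraint = ThoughtVector(
    embedding=GensimWordEmbedding("/path/to/word2vec.bin"),  # hypothetical path
    threshold=0.2,
    metric="max_euclidean",
)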

View File

@@ -115,9 +115,7 @@ class EmbeddingAugmenter(Augmenter):
def __init__(self, **kwargs):
from textattack.transformations import WordSwapEmbedding
transformation = WordSwapEmbedding(
max_candidates=50, embedding_type="paragramcf"
)
transformation = WordSwapEmbedding(max_candidates=50)
from textattack.constraints.semantics import WordEmbeddingDistance
constraints = DEFAULT_CONSTRAINTS + [WordEmbeddingDistance(min_cos_sim=0.8)]
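
The augmenter keeps its public interface; only the implicit embedding construction changed. A brief usage sketch (the input sentence is arbitrary):

from textattack.augmentation import EmbeddingAugmenter

# Uses the counter-fitted GloVe embedding by default; no embedding_type argument.
augmenter = EmbeddingAugmenter()
print(augmenter.augment("The quick brown fox jumps over the lazy dog"))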

View File

@@ -7,7 +7,7 @@ import functools
import torch
from textattack.shared import WordEmbedding, utils
from textattack.shared import TextAttackWordEmbedding, WordEmbedding, utils
from .sentence_encoder import SentenceEncoder
@@ -16,14 +16,19 @@ class ThoughtVector(SentenceEncoder):
"""A constraint on the distance between two sentences' thought vectors.
Args:
word_embedding (str): The word embedding to use
min_cos_sim: the minimum cosine similarity between thought vectors
max_mse_dist: the maximum euclidean distance between thought vectors
word_embedding (textattack.shared.WordEmbedding): The word embedding to use
"""
def __init__(self, embedding_type="paragramcf", embedding_source=None, **kwargs):
self.word_embedding = WordEmbedding(embedding_type, embedding_source)
self.embedding_type = embedding_type
def __init__(
self,
embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
**kwargs
):
if not isinstance(embedding, WordEmbedding):
raise ValueError(
"`embedding` object must be of type `textattack.shared.WordEmbedding`."
)
self.word_embedding = embedding
super().__init__(**kwargs)
def clear_cache(self):
@@ -46,4 +51,4 @@ class ThoughtVector(SentenceEncoder):
def extra_repr_keys(self):
"""Set the extra representation of the constraint using these keys."""
return ["embedding_type"] + super().extra_repr_keys()
return ["word_embedding"] + super().extra_repr_keys()

View File

@@ -1,10 +1,11 @@
"""
Word Embedding Distance
-------------------------
--------------------------
"""
from textattack.constraints import Constraint
from textattack.shared import TextAttackWordEmbedding, WordEmbedding
from textattack.shared.validators import transformation_consists_of_word_swaps
from textattack.shared.word_embedding import WordEmbedding
class WordEmbeddingDistance(Constraint):
@@ -13,24 +14,17 @@ class WordEmbeddingDistance(Constraint):
inserted.
Args:
embedding_type (str): The word embedding to use.
embedding_source (str): The source of embeddings ("defaults" or "gensim")
include_unknown_words (bool): Whether or not the constraint is fulfilled
if the embedding of x or x_adv is unknown.
min_cos_sim: The minimum cosine similarity between word embeddings.
max_mse_dist: The maximum euclidean distance between word embeddings.
cased (bool): Whether embedding supports uppercase & lowercase
(defaults to False, or just lowercase).
compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`.
Otherwise, compare it against the previous `x_adv`.
embedding (obj): Wrapper for word embedding.
include_unknown_words (bool): Whether or not the constraint is fulfilled if the embedding of x or x_adv is unknown.
min_cos_sim (:obj:`float`, optional): The minimum cosine similarity between word embeddings.
max_mse_dist (:obj:`float`, optional): The maximum euclidean distance between word embeddings.
cased (bool): Whether embedding supports uppercase & lowercase (defaults to False, or just lowercase).
compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. Otherwise, compare it against the previous `x_adv`.
"""
PATH = "word_embeddings"
def __init__(
self,
embedding_type="paragramcf",
embedding_source=None,
embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
include_unknown_words=True,
min_cos_sim=None,
max_mse_dist=None,
@@ -40,14 +34,17 @@ class WordEmbeddingDistance(Constraint):
super().__init__(compare_against_original)
self.include_unknown_words = include_unknown_words
self.cased = cased
if bool(min_cos_sim) == bool(max_mse_dist):
raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.")
self.min_cos_sim = min_cos_sim
self.max_mse_dist = max_mse_dist
self.embedding_type = embedding_type
self.embedding_source = embedding_source
self.embedding = WordEmbedding(
embedding_type=embedding_type, embedding_source=embedding_source
)
if not isinstance(embedding, WordEmbedding):
raise ValueError(
"`embedding` object must be of type `textattack.shared.WordEmbedding`."
)
self.embedding = embedding
def get_cos_sim(self, a, b):
"""Returns the cosine similarity of words with IDs a and b."""
@@ -59,7 +56,7 @@ class WordEmbeddingDistance(Constraint):
def _check_constraint(self, transformed_text, reference_text):
"""Returns true if (``transformed_text`` and ``reference_text``) are
closer than ``self.min_cos_sim`` and ``self.max_mse_dist``."""
closer than ``self.min_cos_sim`` or ``self.max_mse_dist``."""
try:
indices = transformed_text.attack_attrs["newly_modified_indices"]
except KeyError:
@@ -77,8 +74,8 @@ class WordEmbeddingDistance(Constraint):
transformed_word = transformed_word.lower()
try:
ref_id = self.embedding.word2ind(ref_word)
transformed_id = self.embedding.word2ind(transformed_word)
ref_id = self.embedding.word2index(ref_word)
transformed_id = self.embedding.word2index(transformed_word)
except KeyError:
# This error is thrown if x or x_adv has no corresponding ID.
if self.include_unknown_words:
@@ -116,7 +113,7 @@ class WordEmbeddingDistance(Constraint):
else:
metric = "min_cos_sim"
return [
"embedding_type",
"embedding",
metric,
"cased",
"include_unknown_words",

View File

@@ -13,6 +13,6 @@ from .utils import logger
from . import validators
from .attacked_text import AttackedText
from .word_embedding import WordEmbedding
from .word_embedding import *
from .attack import Attack
from .checkpoint import Checkpoint

View File

@@ -49,15 +49,18 @@ def words_from_text(s, words_to_ignore=[]):
def default_class_repr(self):
extra_params = []
for key in self.extra_repr_keys():
extra_params.append(" (" + key + ")" + ": {" + key + "}")
if len(extra_params):
extra_str = "\n" + "\n".join(extra_params) + "\n"
extra_str = f"({extra_str})"
if hasattr(self, "extra_repr_keys"):
extra_params = []
for key in self.extra_repr_keys():
extra_params.append(" (" + key + ")" + ": {" + key + "}")
if len(extra_params):
extra_str = "\n" + "\n".join(extra_params) + "\n"
extra_str = f"({extra_str})"
else:
extra_str = ""
extra_str = extra_str.format(**self.__dict__)
else:
extra_str = ""
extra_str = extra_str.format(**self.__dict__)
return f"{self.__class__.__name__}{extra_str}"

View File

@@ -3,6 +3,7 @@ Shared loads word embeddings and related distances
=====================================================
"""
from abc import ABC, abstractmethod
from collections import defaultdict
import os
import pickle
@@ -10,77 +11,260 @@ import pickle
import numpy as np
import torch
import textattack
from textattack.shared import utils
class WordEmbedding:
"""An object that loads word embeddings and related distances.
class WordEmbedding(ABC):
"""Abstract class representing word embedding used by TextAttack.
This class specifies all the methods that are required to be defined
so that it can be used for transformations and constraints. For a
custom word embedding not supported by TextAttack, please create a
class that inherits from this class and implements the required methods.
However, please first check whether you can use the TextAttackWordEmbedding
object, which already implements most of the required methods.
"""
@abstractmethod
def __getitem__(self, index):
"""Gets the embedding vector for word/id
Args:
index (Union[str|int]): `index` can either be word or integer representing the id of the word.
Returns:
vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`.
"""
raise NotImplementedError()
@abstractmethod
def get_mse_dist(self, a, b):
"""Return MSE (i.e. L2-norm) distance between vector for word `a` and
vector for word `b`.
Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value.
Args:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): MSE (L2) distance
"""
raise NotImplementedError()
@abstractmethod
def get_cos_sim(self, a, b):
"""Return cosine similarity between vector for word `a` and vector for
word `b`.
Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
Args:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): cosine similarity
"""
raise NotImplementedError()
@abstractmethod
def word2index(self, word):
"""
Convert a word to its id (i.e. the index of the word in the embedding matrix)
Args:
word (str)
Returns:
index (int)
"""
raise NotImplementedError()
@abstractmethod
def index2word(self, index):
"""
Convert index to corresponding word
Args:
index (int)
Returns:
word (str)
"""
raise NotImplementedError()
@abstractmethod
def nearest_neighbours(self, index, topn):
"""
Get top-N nearest neighbours for a word
Args:
index (int): ID of the word for which we're finding the nearest neighbours
topn (int): Used for specifying N nearest neighbours
Returns:
neighbours (list[int]): List of indices of the nearest neighbours
"""
raise NotImplementedError()
__repr__ = __str__ = utils.default_class_repr
class TextAttackWordEmbedding(WordEmbedding):
"""An object that loads word embeddings and related distances for
TextAttack. This has a lot of internal components (e.g. cosine
similarity) implemented. Consider using this class if you can provide the
appropriate input data to create the object.
Args:
embedding_type (str): The type of the embedding to load automatically
embedding_source (str): Source of embeddings provided,
"defaults" corresponds to the textattack s3 bucket
"gensim" expects a word2vec model
embedding_matrix (ndarray): 2-D array of NxD where N is the size of the vocab and D is the dimension of the embedding vectors.
word2index (Union[dict|object]): dictionary (or a similar object) that maps a word to its index within the embedding matrix.
index2word (Union[dict|object]): dictionary (or a similar object) that maps an index to its word.
nn_matrix (ndarray): Precomputed nearest-neighbours matrix. It should be a 2-D integer array of NxK
where N is the size of the vocab and K is the number of nearest neighbours. If this is not provided, nearest neighbours
are computed on the fly in the `nearest_neighbours` method, which is costly.
"""
PATH = "word_embeddings"
EMBEDDINGS_AVAILABLE_IN_TEXTATTACK = {"paragramcf"}
def __init__(self, embedding_type=None, embedding_source=None):
def __init__(self, embedding_matrix, word2index, index2word, nn_matrix=None):
self.embedding_matrix = embedding_matrix
self._word2index = word2index
self._index2word = index2word
self.nn_matrix = nn_matrix
self.embedding_type = embedding_type or "paragramcf"
self.embedding_source = embedding_source or "defaults"
# Dictionary for caching results
self._mse_dist_mat = defaultdict(dict)
self._cos_sim_mat = defaultdict(dict)
self._nn_cache = {}
if self.embedding_source == "defaults":
if (
self.embedding_type
not in WordEmbedding.EMBEDDINGS_AVAILABLE_IN_TEXTATTACK
):
raise ValueError(
f"{self.embedding_type} is not available in TextAttack."
)
elif self.embedding_source not in ["gensim"]:
raise ValueError(f"{self.embedding_source} type is not supported.")
self._embeddings = None
self._word2index = None
self._index2word = None
self._cos_sim_mat = None
self._mse_dist_mat = None
self._nn = None
self._gensim_keyed_vectors = None
self._init_embeddings_from_type(self.embedding_source, self.embedding_type)
def _init_embeddings_from_defaults(self, embedding_type):
"""
Init embeddings prepared in the textattack s3 bucket
def __getitem__(self, index):
"""Gets the embedding vector for word/id
Args:
embedding_type:
index (Union[str|int]): `index` can either be word or integer representing the id of the word.
Returns:
vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`.
"""
if isinstance(index, str):
try:
index = self._word2index[index]
except KeyError:
return None
try:
return self.embedding_matrix[index]
except IndexError:
# word embedding ID out of bounds
return None
def word2index(self, word):
"""
Convert a word to its id (i.e. the index of the word in the embedding matrix)
Args:
word (str)
Returns:
index (int)
"""
return self._word2index[word]
def index2word(self, index):
"""
Convert index to corresponding word
Args:
index (int)
Returns:
word (str)
"""
if embedding_type == "paragramcf":
word_embeddings_folder = "paragramcf"
word_embeddings_file = "paragram.npy"
word_list_file = "wordlist.pickle"
mse_dist_file = "mse_dist.p"
cos_sim_file = "cos_sim.p"
nn_matrix_file = "nn.npy"
return self._index2word[index]
def get_mse_dist(self, a, b):
"""Return MSE (i.e. L2-norm) distance between vector for word `a` and
vector for word `b`.
Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value.
Args:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): MSE (L2) distance
"""
if isinstance(a, str):
a = self._word2index[a]
if isinstance(b, str):
b = self._word2index[b]
a, b = min(a, b), max(a, b)
try:
mse_dist = self._mse_dist_mat[a][b]
except KeyError:
e1 = self.embedding_matrix[a]
e2 = self.embedding_matrix[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
mse_dist = torch.sum((e1 - e2) ** 2).item()
self._mse_dist_mat[a][b] = mse_dist
return mse_dist
def get_cos_sim(self, a, b):
"""Return cosine similarity between vector for word `a` and vector for
word `b`.
Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
Args:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): cosine similarity
"""
if isinstance(a, str):
a = self._word2index[a]
if isinstance(b, str):
b = self._word2index[b]
a, b = min(a, b), max(a, b)
try:
cos_sim = self._cos_sim_mat[a][b]
except KeyError:
e1 = self.embedding_matrix[a]
e2 = self.embedding_matrix[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item()
self._cos_sim_mat[a][b] = cos_sim
return cos_sim
def nearest_neighbours(self, index, topn):
"""
Get top-N nearest neighbours for a word
Args:
index (int): ID of the word for which we're finding the nearest neighbours
topn (int): Used for specifying N nearest neighbours
Returns:
neighbours (list[int]): List of indices of the nearest neighbours
"""
if isinstance(index, str):
index = self._word2index[index]
if self.nn_matrix is not None:
nn = self.nn_matrix[index][1 : (topn + 1)]
else:
raise ValueError(f"Could not find word embedding {embedding_type}")
try:
nn = self._nn_cache[index]
except KeyError:
embedding = torch.tensor(self.embedding_matrix).to(utils.device)
vector = torch.tensor(self.embedding_matrix[index]).to(utils.device)
dist = torch.norm(embedding - vector, dim=1, p=None)
# Since the closest neighbour will be the word itself, fetch N+1 neighbours and drop the first.
nn = dist.topk(topn + 1, largest=False).indices[1:].tolist()
self._nn_cache[index] = nn
return nn
@staticmethod
def counterfitted_GLOVE_embedding():
"""Returns a prebuilt counter-fitted GLOVE word embedding proposed by
"Counter-fitting Word Vectors to Linguistic Constraints" (Mrkšić et
al., 2016)"""
word_embeddings_folder = "paragramcf"
word_embeddings_file = "paragram.npy"
word_list_file = "wordlist.pickle"
mse_dist_file = "mse_dist.p"
cos_sim_file = "cos_sim.p"
nn_matrix_file = "nn.npy"
# Download embeddings if they're not cached.
word_embeddings_folder = os.path.join(
WordEmbedding.PATH, word_embeddings_folder
TextAttackWordEmbedding.PATH, word_embeddings_folder
)
word_embeddings_folder = textattack.shared.utils.download_if_needed(
word_embeddings_folder
)
word_embeddings_folder = utils.download_if_needed(word_embeddings_folder)
# Concatenate folder names to create full path to files.
word_embeddings_file = os.path.join(
word_embeddings_folder, word_embeddings_file
@@ -91,234 +275,152 @@ class WordEmbedding:
nn_matrix_file = os.path.join(word_embeddings_folder, nn_matrix_file)
# loading the files
self._embeddings = np.load(word_embeddings_file)
self._word2index = np.load(word_list_file, allow_pickle=True)
if os.path.exists(mse_dist_file):
with open(mse_dist_file, "rb") as f:
self._mse_dist_mat = pickle.load(f)
else:
self._mse_dist_mat = defaultdict(dict)
if os.path.exists(cos_sim_file):
with open(cos_sim_file, "rb") as f:
self._cos_sim_mat = pickle.load(f)
else:
self._cos_sim_mat = defaultdict(dict)
embedding_matrix = np.load(word_embeddings_file)
word2index = np.load(word_list_file, allow_pickle=True)
index2word = {}
for word, index in word2index.items():
index2word[index] = word
nn_matrix = np.load(nn_matrix_file)
self._nn = np.load(nn_matrix_file)
word_embedding = TextAttackWordEmbedding(
embedding_matrix, word2index, index2word, nn_matrix
)
with open(mse_dist_file, "rb") as f:
mse_dist_mat = pickle.load(f)
with open(cos_sim_file, "rb") as f:
cos_sim_mat = pickle.load(f)
# Build glove dict and index.
self._index2word = dict()
for word, index in self._word2index.items():
self._index2word[index] = word
word_embedding._mse_dist_mat = mse_dist_mat
word_embedding._cos_sim_mat = cos_sim_mat
def _init_embeddings_from_gensim(self, embedding_type):
"""
Initialize word embedding from a gensim word2vec model
Args:
embedding_type:
return word_embedding
Returns:
"""
import gensim
class GensimWordEmbedding(WordEmbedding):
"""Wraps Gensim's KeyedVectors
(https://radimrehurek.com/gensim/models/keyedvectors.html)"""
if embedding_type.endswith(".bin"):
self._gensim_keyed_vectors = (
gensim.models.KeyedVectors.load_word2vec_format(
embedding_type, binary=True
def __init__(self, keyed_vectors_or_path):
gensim = utils.LazyLoader("gensim", globals(), "gensim")
if isinstance(keyed_vectors_or_path, str):
if keyed_vectors_or_path.endswith(".bin"):
self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(
keyed_vectors_or_path, binary=True
)
)
else:
self.keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(
keyed_vectors_or_path
)
elif isinstance(keyed_vectors_or_path, gensim.models.KeyedVectors):
self.keyed_vectors = keyed_vectors_or_path
else:
self._gensim_keyed_vectors = (
gensim.models.KeyedVectors.load_word2vec_format(embedding_type)
raise ValueError(
"`keyed_vectors_or_path` argument must either be `gensim.models.KeyedVectors` object "
"or a path pointing to the saved KeyedVector object"
)
self._gensim_keyed_vectors.init_sims()
def _init_embeddings_from_type(self, embedding_source, embedding_type):
"""Initializes embedding based on the source.
Downloads and loads embeddings into memory.
"""
if embedding_source == "defaults":
self._init_embeddings_from_defaults(embedding_type)
elif embedding_source == "gensim":
self._init_embeddings_from_gensim(embedding_type)
else:
raise ValueError(f"Not supported word embedding source {embedding_source}")
self.keyed_vectors.init_sims()
self._mse_dist_mat = defaultdict(dict)
self._cos_sim_mat = defaultdict(dict)
def __getitem__(self, index):
"""Gets a word embedding by word or ID.
If word or ID not found, returns None.
"""
if self.embedding_source == "defaults":
if isinstance(index, str):
try:
index = self._word2index[index]
except KeyError:
return None
try:
return self._embeddings[index]
except IndexError:
# word embedding ID out of bounds
return None
elif self.embedding_source == "gensim":
if isinstance(index, str):
try:
index = self._gensim_keyed_vectors.vocab.get(index).index
except KeyError:
return None
try:
return self._gensim_keyed_vectors.vectors_norm[index]
except IndexError:
# word embedding ID out of bounds
return None
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
def get_cos_sim(self, a, b):
"""
get cosine similarity of two words/IDs
"""Gets the embedding vector for word/id
Args:
a:
b:
index (Union[str|int]): `index` can either be word or integer representing the id of the word.
Returns:
vector (ndarray): 1-D embedding vector. If corresponding vector cannot be found for `index`, returns `None`.
"""
if isinstance(index, str):
try:
index = self.keyed_vectors.vocab.get(index).index
except KeyError:
return None
try:
return self.keyed_vectors.vectors_norm[index]
except IndexError:
# word embedding ID out of bounds
return None
def word2index(self, word):
"""
Convert a word to its id (i.e. the index of the word in the embedding matrix)
Args:
word (str)
Returns:
index (int)
"""
vocab = self.keyed_vectors.vocab.get(word)
if vocab is None:
raise KeyError(word)
return vocab.index
def index2word(self, index):
"""
Convert index to corresponding word
Args:
index (int)
Returns:
word (str)
"""
if self.embedding_source == "defaults":
if isinstance(a, str):
a = self._word2index[a]
if isinstance(b, str):
b = self._word2index[b]
a, b = min(a, b), max(a, b)
try:
cos_sim = self._cos_sim_mat[a][b]
except KeyError:
e1 = self._embeddings[a]
e2 = self._embeddings[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2).item()
self._cos_sim_mat[a][b] = cos_sim
return cos_sim
elif self.embedding_source == "gensim":
if not isinstance(a, str):
a = self._gensim_keyed_vectors.index2word[a]
if not isinstance(b, str):
b = self._gensim_keyed_vectors.index2word[b]
cos_sim = self._gensim_keyed_vectors.similarity(a, b)
return cos_sim
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
try:
# this is a list, so the error would be IndexError
return self.keyed_vectors.index2word[index]
except IndexError:
raise KeyError(index)
def get_mse_dist(self, a, b):
"""
get mse distance of two IDs
"""Return MSE (i.e. L2-norm) distance between vector for word `a` and
vector for word `b`.
Since this is a metric, `get_mse_dist(a,b)` and `get_mse_dist(b,a)` should return the same value.
Args:
a:
b:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): MSE (L2) distance
"""
if self.embedding_source == "defaults":
a, b = min(a, b), max(a, b)
try:
mse_dist = self._mse_dist_mat[a][b]
except KeyError:
e1 = self._embeddings[a]
e2 = self._embeddings[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
mse_dist = torch.sum((e1 - e2) ** 2).item()
self._mse_dist_mat[a][b] = mse_dist
return mse_dist
elif self.embedding_source == "gensim":
if self._mse_dist_mat is None:
self._mse_dist_mat = defaultdict(dict)
try:
mse_dist = self._mse_dist_mat[a][b]
except KeyError:
e1 = self._gensim_keyed_vectors.vectors_norm[a]
e2 = self._gensim_keyed_vectors.vectors_norm[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
mse_dist = torch.sum((e1 - e2) ** 2).item()
self._mse_dist_mat[a][b] = mse_dist
return mse_dist
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
try:
mse_dist = self._mse_dist_mat[a][b]
except KeyError:
e1 = self.keyed_vectors.vectors_norm[a]
e2 = self.keyed_vectors.vectors_norm[b]
e1 = torch.tensor(e1).to(utils.device)
e2 = torch.tensor(e2).to(utils.device)
mse_dist = torch.sum((e1 - e2) ** 2).item()
self._mse_dist_mat[a][b] = mse_dist
return mse_dist
def word2ind(self, word):
"""
word to index
def get_cos_sim(self, a, b):
"""Return cosine similarity between vector for word `a` and vector for
word `b`.
Since cosine similarity is symmetric, `get_cos_sim(a,b)` and `get_cos_sim(b,a)` should return the same value.
Args:
word:
a (Union[str|int]): Either word or integer representing the id of the word
b (Union[str|int]): Either word or integer representing the id of the word
Returns:
distance (float): cosine similarity
"""
if self.embedding_source == "defaults":
return self._word2index[word]
elif self.embedding_source == "gensim":
vocab = self._gensim_keyed_vectors.vocab.get(word)
if vocab is None:
raise KeyError(word)
return vocab.index
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
if not isinstance(a, str):
a = self.keyed_vectors.index2word[a]
if not isinstance(b, str):
b = self.keyed_vectors.index2word[b]
cos_sim = self.keyed_vectors.similarity(a, b)
return cos_sim
def ind2word(self, index):
def nearest_neighbours(self, index, topn, return_words=True):
"""
index to word
Get top-N nearest neighbours for a word
Args:
index:
index (int): ID of the word for which we're finding the nearest neighbours
topn (int): Used for specifying N nearest neighbours
Returns:
neighbours (list[int]): List of indices of the nearest neighbours
"""
if self.embedding_source == "defaults":
return self._index2word[index]
elif self.embedding_source == "gensim":
try:
# this is a list, so the error would be IndexError
return self._gensim_keyed_vectors.index2word[index]
except IndexError:
raise KeyError(index)
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
def nn(self, index, topn):
"""
get top n nearest neighbours for a word
Args:
index:
topn:
Returns:
"""
if self.embedding_source == "defaults":
return self._nn[index][1 : (topn + 1)]
elif self.embedding_source == "gensim":
word = self._gensim_keyed_vectors.index2word[index]
return [
self._gensim_keyed_vectors.index2word.index(i[0])
for i in self._gensim_keyed_vectors.similar_by_word(word, topn)
]
else:
raise ValueError(
f"Not supported word embedding source {self.embedding_source}"
)
word = self.keyed_vectors.index2word[index]
return [
self.keyed_vectors.index2word.index(i[0])
for i in self.keyed_vectors.similar_by_word(word, topn)
]
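
The abstract WordEmbedding class above is meant to be subclassed for embeddings that TextAttack does not ship. A minimal sketch of such a subclass, backed by an in-memory dict of toy vectors (everything below is illustrative and not part of this commit):

import numpy as np

from textattack.shared import WordEmbedding

class DictWordEmbedding(WordEmbedding):
    """Toy WordEmbedding backed by a {word: vector} dict."""

    def __init__(self, word_to_vec):
        self._words = list(word_to_vec)
        self._w2i = {w: i for i, w in enumerate(self._words)}
        self._matrix = np.array([word_to_vec[w] for w in self._words], dtype=float)

    def __getitem__(self, index):
        if isinstance(index, str):
            if index not in self._w2i:
                return None
            index = self._w2i[index]
        return self._matrix[index] if 0 <= index < len(self._matrix) else None

    def word2index(self, word):
        return self._w2i[word]

    def index2word(self, index):
        return self._words[index]

    def get_mse_dist(self, a, b):
        return float(np.sum((self[a] - self[b]) ** 2))

    def get_cos_sim(self, a, b):
        e1, e2 = self[a], self[b]
        return float(np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2)))

    def nearest_neighbours(self, index, topn):
        dist = np.linalg.norm(self._matrix - self._matrix[index], axis=1)
        # Drop the word itself (distance 0) and keep the next `topn` indices.
        return np.argsort(dist)[1 : topn + 1].tolist()

embedding = DictWordEmbedding({"hi": [1.0, 0.0], "hello": [0.9, 0.1], "bye": [-1.0, 0.0]})
print(embedding.nearest_neighbours(0, 1))  # [1], i.e. "hello"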

View File

@@ -3,30 +3,32 @@ Word Swap by Embedding
============================================
"""
from textattack.shared.word_embedding import WordEmbedding
from textattack.shared import TextAttackWordEmbedding, WordEmbedding
from textattack.transformations.word_swap import WordSwap
class WordSwapEmbedding(WordSwap):
"""Transforms an input by replacing its words with synonyms in the word
embedding space."""
embedding space.
PATH = "word_embeddings"
Args:
max_candidates (int): maximum number of synonyms to pick
embedding (textattack.shared.WordEmbedding): Wrapper for word embedding
"""
def __init__(
self,
max_candidates=15,
embedding_type="paragramcf",
embedding_source=None,
embedding=TextAttackWordEmbedding.counterfitted_GLOVE_embedding(),
**kwargs
):
super().__init__(**kwargs)
self.max_candidates = max_candidates
self.embedding_type = embedding_type
self.embedding_source = embedding_source
self.embedding = WordEmbedding(
embedding_type=embedding_type, embedding_source=embedding_source
)
if not isinstance(embedding, WordEmbedding):
raise ValueError(
"`embedding` object must be of type `textattack.shared.WordEmbedding`."
)
self.embedding = embedding
def _get_replacement_words(self, word):
"""Returns a list of possible 'candidate words' to replace a word in a
@@ -35,11 +37,11 @@ class WordSwapEmbedding(WordSwap):
Based on nearest neighbors selected word embeddings.
"""
try:
word_id = self.embedding.word2ind(word.lower())
nnids = self.embedding.nn(word_id, self.max_candidates)
word_id = self.embedding.word2index(word.lower())
nnids = self.embedding.nearest_neighbours(word_id, self.max_candidates)
candidate_words = []
for i, nbr_id in enumerate(nnids):
nbr_word = self.embedding.ind2word(nbr_id)
nbr_word = self.embedding.index2word(nbr_id)
candidate_words.append(recover_word_case(nbr_word, word))
return candidate_words
except KeyError:
@@ -47,7 +49,7 @@ class WordSwapEmbedding(WordSwap):
return []
def extra_repr_keys(self):
return ["max_candidates", "embedding_type"]
return ["max_candidates", "embedding"]
def recover_word_case(word, reference_word):
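
Putting the pieces together, a sketch of constructing the transformation under the new signature, with either the default or a gensim-backed embedding (the word2vec path is hypothetical):

from textattack.shared import GensimWordEmbedding
from textattack.transformations import WordSwapEmbedding

# Default: counter-fitted GloVe vectors, 15 candidates per word.
swap = WordSwapEmbedding(max_candidates=15)

# Or back the same transformation with any gensim word2vec model.
swap_w2v = WordSwapEmbedding(
    max_candidates=15,
    embedding=GensimWordEmbedding("/path/to/word2vec.bin"),  # hypothetical path
)

# Internal helper shown in the hunk above; returns nearest-neighbour synonyms.
print(swap._get_replacement_words("happy"))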