1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

support custom word embedding

This commit is contained in:
Jack Morris
2020-09-11 12:24:51 -04:00
parent 4fc83907c6
commit f97475b98d
2 changed files with 56 additions and 3 deletions

View File

@@ -10,12 +10,36 @@ class WordEmbedding:
"""An object that loads word embeddings and related distances.
Args:
embedding_type (str): The type of the embedding to load
embedding_type (str): The type of the embedding to load automatically
embeddings: A dictionary or matrix that maps word embeddings from
their IDs to vectors (torch tensors or numpy ndarrays). Must be
provided for custom embeddings, when embedding_type is not provided.
If both `embedding_type` and `embeddings` are provided, `embeddings`
overrides `embedding_type`.
word2index: A dictionary that maps words by string to ID. Not required,
but useful if the user intends to use this class to look up word
embeddings by their string. Can omit this argument and solely query
words by ID from `self.embeddings`.
"""
PATH = "word_embeddings"
def __init__(self, embedding_type="paragramcf"):
def __init__(self, embedding_type="paragramcf", embeddings=None, word2index=None):
if embeddings is not None:
self.embeddings = embeddings
self.word2index = word2index or {}
elif embedding_type:
self._init_embeddings_from_type(embedding_type)
else:
raise ValueError(
"Must supply `embedding_type` or `embeddings` as parameters."
)
def _init_embeddings_from_type(self, embedding_type):
"""Initializes self.embeddings based on the type of embedding.
Downloads and loads embeddings into memory.
"""
self.embedding_type = embedding_type
if embedding_type == "paragramcf":
word_embeddings_folder = "paragramcf"
@@ -69,4 +93,8 @@ class WordEmbedding:
index = self.word2index[index]
except KeyError:
return None
return self.embeddings[index]
try:
return self.embeddings[index]
except IndexError:
# word embedding ID out of bounds
return None