mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
support custom word embedding
This commit is contained in:
@@ -10,12 +10,36 @@ class WordEmbedding:
|
||||
"""An object that loads word embeddings and related distances.
|
||||
|
||||
Args:
|
||||
embedding_type (str): The type of the embedding to load
|
||||
embedding_type (str): The type of the embedding to load automatically
|
||||
embeddings: A dictionary or matrix that maps word IDs to their
|
||||
embedding vectors (torch tensors or numpy ndarrays). Must be
|
||||
provided for custom embeddings, when embedding_type is not provided.
|
||||
If both `embedding_type` and `embeddings` are provided, `embeddings`
|
||||
overrides `embedding_type`.
|
||||
word2index: A dictionary that maps words by string to ID. Not required,
|
||||
but useful if the user intends to use this class to look up word
|
||||
embeddings by their string. Can omit this argument and solely query
|
||||
words by ID from `self.embeddings`.
|
||||
"""
|
||||
|
||||
PATH = "word_embeddings"
|
||||
|
||||
def __init__(self, embedding_type="paragramcf"):
|
||||
def __init__(self, embedding_type="paragramcf", embeddings=None, word2index=None):
    """Set up the embedding table from caller-supplied data or by type name.

    Args:
        embedding_type (str): Name handed to
            ``_init_embeddings_from_type`` when no ``embeddings`` are given.
        embeddings: Stored unchanged on ``self.embeddings`` when not
            ``None``; takes precedence over ``embedding_type``.
        word2index: Stored on ``self.word2index`` in the custom-embeddings
            case; a falsy value is replaced by an empty dict.

    Raises:
        ValueError: If ``embeddings`` is ``None`` and ``embedding_type``
            is falsy.
    """
    if embeddings is None:
        # No custom table supplied: fall back to loading by type name.
        if not embedding_type:
            raise ValueError(
                "Must supply `embedding_type` or `embeddings` as parameters."
            )
        self._init_embeddings_from_type(embedding_type)
        return

    # Custom embeddings win over `embedding_type`; the type-based loader
    # is not invoked on this path, so only these two attributes are set here.
    self.embeddings = embeddings
    self.word2index = word2index if word2index else {}
|
||||
|
||||
def _init_embeddings_from_type(self, embedding_type):
|
||||
"""Initializes self.embeddings based on the type of embedding.
|
||||
|
||||
Downloads and loads embeddings into memory.
|
||||
"""
|
||||
self.embedding_type = embedding_type
|
||||
if embedding_type == "paragramcf":
|
||||
word_embeddings_folder = "paragramcf"
|
||||
@@ -69,4 +93,8 @@ class WordEmbedding:
|
||||
index = self.word2index[index]
|
||||
except KeyError:
|
||||
return None
|
||||
return self.embeddings[index]
|
||||
try:
|
||||
return self.embeddings[index]
|
||||
except IndexError:
|
||||
# word embedding ID out of bounds
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user