1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/shared/word_embedding.py
2020-06-17 19:47:33 -04:00

73 lines
2.4 KiB
Python

import os
import pickle
import numpy as np
import textattack
class WordEmbedding:
    """An object that loads word embeddings and related distance matrices.

    Args:
        embedding_type (str): The type of the embedding to load.
            Currently only "paragramcf" is supported.

    Raises:
        ValueError: If ``embedding_type`` is not a known embedding.
    """

    # Remote/cache folder name passed to textattack's downloader.
    PATH = "word_embeddings"

    def __init__(self, embedding_type="paragramcf"):
        self.embedding_type = embedding_type
        if embedding_type == "paragramcf":
            word_embeddings_folder = "paragramcf"
            word_embeddings_file = "paragram.npy"
            word_list_file = "wordlist.pickle"
            mse_dist_file = "mse_dist.p"
            cos_sim_file = "cos_sim.p"
        else:
            # Fix: the original referenced the undefined name `word_embedding`
            # here, so this branch raised NameError instead of the intended
            # ValueError. Use the actual parameter, `embedding_type`.
            raise ValueError(f"Could not find word embedding {embedding_type}")

        # Download embeddings if they're not cached.
        word_embeddings_root_path = textattack.shared.utils.download_if_needed(
            WordEmbedding.PATH
        )
        word_embeddings_folder = os.path.join(
            word_embeddings_root_path, word_embeddings_folder
        )

        # Concatenate folder names to create full paths to files.
        word_embeddings_file = os.path.join(
            word_embeddings_folder, word_embeddings_file
        )
        word_list_file = os.path.join(word_embeddings_folder, word_list_file)
        mse_dist_file = os.path.join(word_embeddings_folder, mse_dist_file)
        cos_sim_file = os.path.join(word_embeddings_folder, cos_sim_file)

        # Actually load the files from disk.
        self.embeddings = np.load(word_embeddings_file)
        # NOTE(review): wordlist.pickle presumably holds a word -> ID mapping
        # (a dict pickled via np.save); allow_pickle=True is required to
        # unpickle object arrays. Verify against the published data files.
        self.word2index = np.load(word_list_file, allow_pickle=True)

        # Precomputed distance matrices store distances at mat[x][y], where
        # x and y are word IDs and x < y. Missing files fall back to empty
        # dicts so distances can be computed lazily by callers.
        if os.path.exists(mse_dist_file):
            with open(mse_dist_file, "rb") as f:
                self.mse_dist_mat = pickle.load(f)
        else:
            self.mse_dist_mat = {}
        if os.path.exists(cos_sim_file):
            with open(cos_sim_file, "rb") as f:
                self.cos_sim_mat = pickle.load(f)
        else:
            self.cos_sim_mat = {}

    def __getitem__(self, index):
        """Gets a word embedding by word (str) or ID (int).

        Returns None if the word or ID is not found.
        """
        if isinstance(index, str):
            try:
                index = self.word2index[index]
            except KeyError:
                return None
        try:
            return self.embeddings[index]
        except IndexError:
            # Fix: an out-of-range ID previously raised IndexError, which
            # contradicted the documented contract of returning None when
            # the word or ID is not found.
            return None