1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

model testing script & fix imports

This commit is contained in:
Jack Morris
2019-11-10 21:39:25 -05:00
parent 2c846f5559
commit a3308cd1fe
11 changed files with 88 additions and 51 deletions

View File

@@ -5,6 +5,8 @@ import torch.nn as nn
import textattack.utils as utils
logger = utils.get_logger()
class EmbeddingLayer(nn.Module):
"""
A layer of a model that replaces word IDs with their embeddings.
@@ -24,7 +26,7 @@ class EmbeddingLayer(nn.Module):
assert word not in word2id, "Duplicate words in pre-trained embeddings"
word2id[word] = len(word2id)
print(f'{len(word2id)} pre-trained word embeddings loaded.\n')
logger.debug(f'{len(word2id)} pre-trained word embeddings loaded.\n')
n_d = len(embvecs[0])