mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
model testing script & fix imports
This commit is contained in:
@@ -5,6 +5,8 @@ import torch.nn as nn
|
||||
|
||||
import textattack.utils as utils
|
||||
|
||||
logger = utils.get_logger()
|
||||
|
||||
class EmbeddingLayer(nn.Module):
|
||||
"""
|
||||
A layer of a model that replaces word IDs with their embeddings.
|
||||
@@ -24,7 +26,7 @@ class EmbeddingLayer(nn.Module):
|
||||
assert word not in word2id, "Duplicate words in pre-trained embeddings"
|
||||
word2id[word] = len(word2id)
|
||||
|
||||
print(f'{len(word2id)} pre-trained word embeddings loaded.\n')
|
||||
logger.debug(f'{len(word2id)} pre-trained word embeddings loaded.\n')
|
||||
|
||||
n_d = len(embvecs[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user