1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
This commit is contained in:
Jin Yong Yoo
2020-11-10 21:45:38 -05:00
parent 2328b7f76c
commit b148bd2818
5 changed files with 10 additions and 6 deletions

View File

@@ -19,4 +19,5 @@ tokenizers==0.8.1-rc2
tqdm>=4.27,<4.50.0
word2number
num2words
more-itertools
more-itertools
PySocks!=1.5.7,>=1.5.6

View File

@@ -18,7 +18,8 @@ from textattack.commands.train_model import TrainModelCommand
def main():
"""This is the main command line parer and entry function to use TextAttack via command lines
"""This is the main command line parer and entry function to use TextAttack
via command lines.
texattack <command> [<args>]

View File

@@ -19,7 +19,6 @@ from . import lm_data_utils, lm_utils
tf = utils.LazyLoader("tensorflow", globals(), "tensorflow")
tf.get_logger().setLevel("INFO")
# @TODO automatically choose between GPU and CPU.
@@ -31,6 +30,7 @@ class GoogLMHelper:
CACHE_PATH = "constraints/semantics/language-models/alzantot-goog-lm"
def __init__(self):
tf.get_logger().setLevel("INFO")
lm_folder = utils.download_if_needed(GoogLMHelper.CACHE_PATH)
self.PBTXT_PATH = os.path.join(lm_folder, "graph-2016-09-10-gpu.pbtxt")
self.CKPT_PATH = os.path.join(lm_folder, "ckpt-*")

View File

@@ -13,8 +13,6 @@ tf = LazyLoader("tensorflow", globals(), "tensorflow")
from google.protobuf import text_format # noqa: E402
tf.get_logger().setLevel("INFO")
def LoadModel(sess, graph, gd_file, ckpt_file):
"""Load the model from GraphDef and Checkpoint.
@@ -26,6 +24,7 @@ def LoadModel(sess, graph, gd_file, ckpt_file):
Returns:
TensorFlow session and tensors dict.
"""
tf.get_logger().setLevel("INFO")
with graph.as_default():
sys.stderr.write("Recovering graph.\n")
with tf.io.gfile.GFile(gd_file) as f:

View File

@@ -7,10 +7,13 @@ import os
import numpy as np
import torch
import torchfile
from textattack.shared.utils import LazyLoader
from .rnn_model import RNNModel
torchfile = LazyLoader("torchfile", globals(), "torchfile")
class QueryHandler:
def __init__(self, model, word_to_idx, mapto, device):