mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix bugs
This commit is contained in:
@@ -19,4 +19,5 @@ tokenizers==0.8.1-rc2
|
||||
tqdm>=4.27,<4.50.0
|
||||
word2number
|
||||
num2words
|
||||
more-itertools
|
||||
more-itertools
|
||||
PySocks!=1.5.7,>=1.5.6
|
||||
@@ -18,7 +18,8 @@ from textattack.commands.train_model import TrainModelCommand
|
||||
|
||||
def main():
|
||||
|
||||
"""This is the main command line parer and entry function to use TextAttack via command lines
|
||||
"""This is the main command line parer and entry function to use TextAttack
|
||||
via command lines.
|
||||
|
||||
texattack <command> [<args>]
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from . import lm_data_utils, lm_utils
|
||||
|
||||
tf = utils.LazyLoader("tensorflow", globals(), "tensorflow")
|
||||
|
||||
tf.get_logger().setLevel("INFO")
|
||||
|
||||
# @TODO automatically choose between GPU and CPU.
|
||||
|
||||
@@ -31,6 +30,7 @@ class GoogLMHelper:
|
||||
CACHE_PATH = "constraints/semantics/language-models/alzantot-goog-lm"
|
||||
|
||||
def __init__(self):
|
||||
tf.get_logger().setLevel("INFO")
|
||||
lm_folder = utils.download_if_needed(GoogLMHelper.CACHE_PATH)
|
||||
self.PBTXT_PATH = os.path.join(lm_folder, "graph-2016-09-10-gpu.pbtxt")
|
||||
self.CKPT_PATH = os.path.join(lm_folder, "ckpt-*")
|
||||
|
||||
@@ -13,8 +13,6 @@ tf = LazyLoader("tensorflow", globals(), "tensorflow")
|
||||
|
||||
from google.protobuf import text_format # noqa: E402
|
||||
|
||||
tf.get_logger().setLevel("INFO")
|
||||
|
||||
|
||||
def LoadModel(sess, graph, gd_file, ckpt_file):
|
||||
"""Load the model from GraphDef and Checkpoint.
|
||||
@@ -26,6 +24,7 @@ def LoadModel(sess, graph, gd_file, ckpt_file):
|
||||
Returns:
|
||||
TensorFlow session and tensors dict.
|
||||
"""
|
||||
tf.get_logger().setLevel("INFO")
|
||||
with graph.as_default():
|
||||
sys.stderr.write("Recovering graph.\n")
|
||||
with tf.io.gfile.GFile(gd_file) as f:
|
||||
|
||||
@@ -7,10 +7,13 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchfile
|
||||
|
||||
from textattack.shared.utils import LazyLoader
|
||||
|
||||
from .rnn_model import RNNModel
|
||||
|
||||
torchfile = LazyLoader("torchfile", globals(), "torchfile")
|
||||
|
||||
|
||||
class QueryHandler:
|
||||
def __init__(self, model, word_to_idx, mapto, device):
|
||||
|
||||
Reference in New Issue
Block a user