1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

interactive fixes

This commit is contained in:
uvafan
2020-03-05 19:48:19 -05:00
parent e5cda5731b
commit bc5e4d57a7
3 changed files with 10 additions and 12 deletions

View File

@@ -234,7 +234,7 @@ def parse_recipe_from_args(model, args):
raise ValueError('Invalid recipe {args.recipe}')
return recipe
def parse_model_and_attack_from_args(args):
def parse_goal_function_and_attack_from_args(args):
if ':' in args.model:
model_name, params = args.model.split(':')
if model_name not in MODEL_CLASS_NAMES:
@@ -259,7 +259,7 @@ def parse_model_and_attack_from_args(args):
attack = eval(f'{ATTACK_CLASS_NAMES[args.attack]}(goal_function, transformation, constraints=constraints)')
else:
raise ValueError(f'Error: unsupported attack {args.attack}')
return model, attack
return goal_function, attack
def parse_logger_from_args(args):# Create logger
attack_logger = textattack.loggers.AttackLogger()

View File

@@ -25,11 +25,11 @@ def attack_from_queue(args, in_queue, out_queue):
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
print('Using GPU #' + str(gpu_id))
set_env_variables(gpu_id)
model, attack = parse_model_and_attack_from_args(args)
_, attack = parse_goal_function_and_attack_from_args(args)
while not in_queue.empty():
try:
label, text = in_queue.get()
results_gen = attack.attack_dataset([(label, text)], num_examples=1)
output, text = in_queue.get()
results_gen = attack.attack_dataset([(output, text)], num_examples=1)
result = next(results_gen)
out_queue.put(result)
except Exception as e:

View File

@@ -23,7 +23,7 @@ def run(args):
start_time = time.time()
# Models and Attack
model, attack = parse_model_and_attack_from_args(args)
goal_function, attack = parse_goal_function_and_attack_from_args(args)
# Logger
attack_logger = parse_logger_from_args(args)
@@ -47,12 +47,10 @@ def run(args):
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer)
pred = attack._call_model([tokenized_text])
label = int(pred.argmax())
result = goal_function.get_results([tokenized_text])[0]
print('Attacking...')
result = next(attack.attack_dataset([(label, text)]))
result = next(attack.attack_dataset([(result.output, text, False)]))
print(result.__str__(color_method='stdout'))
else: