1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

interactive fixes

This commit is contained in:
uvafan
2020-03-05 19:48:19 -05:00
parent e5cda5731b
commit bc5e4d57a7
3 changed files with 10 additions and 12 deletions

View File

@@ -234,7 +234,7 @@ def parse_recipe_from_args(model, args):
raise ValueError('Invalid recipe {args.recipe}') raise ValueError('Invalid recipe {args.recipe}')
return recipe return recipe
def parse_model_and_attack_from_args(args): def parse_goal_function_and_attack_from_args(args):
if ':' in args.model: if ':' in args.model:
model_name, params = args.model.split(':') model_name, params = args.model.split(':')
if model_name not in MODEL_CLASS_NAMES: if model_name not in MODEL_CLASS_NAMES:
@@ -259,7 +259,7 @@ def parse_model_and_attack_from_args(args):
attack = eval(f'{ATTACK_CLASS_NAMES[args.attack]}(goal_function, transformation, constraints=constraints)') attack = eval(f'{ATTACK_CLASS_NAMES[args.attack]}(goal_function, transformation, constraints=constraints)')
else: else:
raise ValueError(f'Error: unsupported attack {args.attack}') raise ValueError(f'Error: unsupported attack {args.attack}')
return model, attack return goal_function, attack
def parse_logger_from_args(args):# Create logger def parse_logger_from_args(args):# Create logger
attack_logger = textattack.loggers.AttackLogger() attack_logger = textattack.loggers.AttackLogger()

View File

@@ -25,11 +25,11 @@ def attack_from_queue(args, in_queue, out_queue):
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2 gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
print('Using GPU #' + str(gpu_id)) print('Using GPU #' + str(gpu_id))
set_env_variables(gpu_id) set_env_variables(gpu_id)
model, attack = parse_model_and_attack_from_args(args) _, attack = parse_goal_function_and_attack_from_args(args)
while not in_queue.empty(): while not in_queue.empty():
try: try:
label, text = in_queue.get() output, text = in_queue.get()
results_gen = attack.attack_dataset([(label, text)], num_examples=1) results_gen = attack.attack_dataset([(output, text)], num_examples=1)
result = next(results_gen) result = next(results_gen)
out_queue.put(result) out_queue.put(result)
except Exception as e: except Exception as e:
@@ -103,4 +103,4 @@ def pytorch_multiprocessing_workaround():
pass pass
if __name__ == '__main__': if __name__ == '__main__':
run(get_args()) run(get_args())

View File

@@ -23,7 +23,7 @@ def run(args):
start_time = time.time() start_time = time.time()
# Models and Attack # Models and Attack
model, attack = parse_model_and_attack_from_args(args) goal_function, attack = parse_goal_function_and_attack_from_args(args)
# Logger # Logger
attack_logger = parse_logger_from_args(args) attack_logger = parse_logger_from_args(args)
@@ -47,12 +47,10 @@ def run(args):
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer) tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer)
pred = attack._call_model([tokenized_text]) result = goal_function.get_results([tokenized_text])[0]
label = int(pred.argmax())
print('Attacking...') print('Attacking...')
result = next(attack.attack_dataset([(label, text)])) result = next(attack.attack_dataset([(result.output, text, False)]))
print(result.__str__(color_method='stdout')) print(result.__str__(color_method='stdout'))
else: else:
@@ -83,4 +81,4 @@ def run(args):
print(f'Attack time: {time.time() - load_time}s') print(f'Attack time: {time.time() - load_time}s')
if __name__ == '__main__': if __name__ == '__main__':
run(get_args()) run(get_args())