mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
interactive fixes
This commit is contained in:
@@ -234,7 +234,7 @@ def parse_recipe_from_args(model, args):
|
|||||||
raise ValueError('Invalid recipe {args.recipe}')
|
raise ValueError('Invalid recipe {args.recipe}')
|
||||||
return recipe
|
return recipe
|
||||||
|
|
||||||
def parse_model_and_attack_from_args(args):
|
def parse_goal_function_and_attack_from_args(args):
|
||||||
if ':' in args.model:
|
if ':' in args.model:
|
||||||
model_name, params = args.model.split(':')
|
model_name, params = args.model.split(':')
|
||||||
if model_name not in MODEL_CLASS_NAMES:
|
if model_name not in MODEL_CLASS_NAMES:
|
||||||
@@ -259,7 +259,7 @@ def parse_model_and_attack_from_args(args):
|
|||||||
attack = eval(f'{ATTACK_CLASS_NAMES[args.attack]}(goal_function, transformation, constraints=constraints)')
|
attack = eval(f'{ATTACK_CLASS_NAMES[args.attack]}(goal_function, transformation, constraints=constraints)')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Error: unsupported attack {args.attack}')
|
raise ValueError(f'Error: unsupported attack {args.attack}')
|
||||||
return model, attack
|
return goal_function, attack
|
||||||
|
|
||||||
def parse_logger_from_args(args):# Create logger
|
def parse_logger_from_args(args):# Create logger
|
||||||
attack_logger = textattack.loggers.AttackLogger()
|
attack_logger = textattack.loggers.AttackLogger()
|
||||||
|
|||||||
@@ -25,11 +25,11 @@ def attack_from_queue(args, in_queue, out_queue):
|
|||||||
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
|
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
|
||||||
print('Using GPU #' + str(gpu_id))
|
print('Using GPU #' + str(gpu_id))
|
||||||
set_env_variables(gpu_id)
|
set_env_variables(gpu_id)
|
||||||
model, attack = parse_model_and_attack_from_args(args)
|
_, attack = parse_goal_function_and_attack_from_args(args)
|
||||||
while not in_queue.empty():
|
while not in_queue.empty():
|
||||||
try:
|
try:
|
||||||
label, text = in_queue.get()
|
output, text = in_queue.get()
|
||||||
results_gen = attack.attack_dataset([(label, text)], num_examples=1)
|
results_gen = attack.attack_dataset([(output, text)], num_examples=1)
|
||||||
result = next(results_gen)
|
result = next(results_gen)
|
||||||
out_queue.put(result)
|
out_queue.put(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -103,4 +103,4 @@ def pytorch_multiprocessing_workaround():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run(get_args())
|
run(get_args())
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def run(args):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Models and Attack
|
# Models and Attack
|
||||||
model, attack = parse_model_and_attack_from_args(args)
|
goal_function, attack = parse_goal_function_and_attack_from_args(args)
|
||||||
|
|
||||||
# Logger
|
# Logger
|
||||||
attack_logger = parse_logger_from_args(args)
|
attack_logger = parse_logger_from_args(args)
|
||||||
@@ -47,12 +47,10 @@ def run(args):
|
|||||||
|
|
||||||
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer)
|
tokenized_text = textattack.shared.tokenized_text.TokenizedText(text, model.tokenizer)
|
||||||
|
|
||||||
pred = attack._call_model([tokenized_text])
|
result = goal_function.get_results([tokenized_text])[0]
|
||||||
label = int(pred.argmax())
|
|
||||||
|
|
||||||
print('Attacking...')
|
print('Attacking...')
|
||||||
|
|
||||||
result = next(attack.attack_dataset([(label, text)]))
|
result = next(attack.attack_dataset([(result.output, text, False)]))
|
||||||
print(result.__str__(color_method='stdout'))
|
print(result.__str__(color_method='stdout'))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -83,4 +81,4 @@ def run(args):
|
|||||||
print(f'Attack time: {time.time() - load_time}s')
|
print(f'Attack time: {time.time() - load_time}s')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run(get_args())
|
run(get_args())
|
||||||
|
|||||||
Reference in New Issue
Block a user