mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
interactive fixes
This commit is contained in:
@@ -25,11 +25,11 @@ def attack_from_queue(args, in_queue, out_queue):
|
||||
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
|
||||
print('Using GPU #' + str(gpu_id))
|
||||
set_env_variables(gpu_id)
|
||||
model, attack = parse_model_and_attack_from_args(args)
|
||||
_, attack = parse_goal_function_and_attack_from_args(args)
|
||||
while not in_queue.empty():
|
||||
try:
|
||||
label, text = in_queue.get()
|
||||
results_gen = attack.attack_dataset([(label, text)], num_examples=1)
|
||||
output, text = in_queue.get()
|
||||
results_gen = attack.attack_dataset([(output, text)], num_examples=1)
|
||||
result = next(results_gen)
|
||||
out_queue.put(result)
|
||||
except Exception as e:
|
||||
@@ -103,4 +103,4 @@ def pytorch_multiprocessing_workaround():
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
run(get_args())
|
||||
run(get_args())
|
||||
|
||||
Reference in New Issue
Block a user