1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix parallerl worker

This commit is contained in:
Jin Yong Yoo
2020-07-13 14:30:45 -04:00
parent e83f4a0091
commit 22ef1bd459

View File

@@ -73,7 +73,6 @@ def run(args, checkpoint=None):
# We could do the same thing with the model, but it's actually faster
# to let each thread have their own copy of the model.
args = torch.multiprocessing.Manager().Namespace(**vars(args))
# start_time = time.time()
if args.checkpoint_resume:
attack_log_manager = checkpoint.log_manager
@@ -86,7 +85,7 @@ def run(args, checkpoint=None):
dataset = parse_dataset_from_args(args)
textattack.shared.logger.info(f"Running on {num_gpus} GPUs")
load_time = time.time()
start_time = time.time()
if args.interactive:
raise RuntimeError("Cannot run in parallel if --interactive set")
@@ -108,7 +107,7 @@ def run(args, checkpoint=None):
worklist.remove(i)
# Start workers.
# pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
# Log results asynchronously and update progress bar.
if args.checkpoint_resume:
num_results = checkpoint.results_count
@@ -178,8 +177,8 @@ def run(args, checkpoint=None):
attack_log_manager.log_summary()
attack_log_manager.flush()
print()
# finish_time = time.time()
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
textattack.shared.logger.info(f"Attack time: {time.time() - start_time}s")
return attack_log_manager.results