mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix parallel worker
This commit is contained in:
@@ -73,7 +73,6 @@ def run(args, checkpoint=None):
|
||||
# We could do the same thing with the model, but it's actually faster
|
||||
# to let each thread have their own copy of the model.
|
||||
args = torch.multiprocessing.Manager().Namespace(**vars(args))
|
||||
# start_time = time.time()
|
||||
|
||||
if args.checkpoint_resume:
|
||||
attack_log_manager = checkpoint.log_manager
|
||||
@@ -86,7 +85,7 @@ def run(args, checkpoint=None):
|
||||
dataset = parse_dataset_from_args(args)
|
||||
|
||||
textattack.shared.logger.info(f"Running on {num_gpus} GPUs")
|
||||
load_time = time.time()
|
||||
start_time = time.time()
|
||||
|
||||
if args.interactive:
|
||||
raise RuntimeError("Cannot run in parallel if --interactive set")
|
||||
@@ -108,7 +107,7 @@ def run(args, checkpoint=None):
|
||||
worklist.remove(i)
|
||||
|
||||
# Start workers.
|
||||
# pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
|
||||
pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
|
||||
# Log results asynchronously and update progress bar.
|
||||
if args.checkpoint_resume:
|
||||
num_results = checkpoint.results_count
|
||||
@@ -178,8 +177,8 @@ def run(args, checkpoint=None):
|
||||
attack_log_manager.log_summary()
|
||||
attack_log_manager.flush()
|
||||
print()
|
||||
# finish_time = time.time()
|
||||
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
|
||||
|
||||
textattack.shared.logger.info(f"Attack time: {time.time() - start_time}s")
|
||||
|
||||
return attack_log_manager.results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user