1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

parallel fix; add --check-robustness training option

This commit is contained in:
Jack Morris
2020-08-04 13:49:20 -04:00
parent 505ddfee10
commit 6cd3a2d9a5
7 changed files with 137 additions and 34 deletions

View File

@@ -20,6 +20,7 @@ logger = textattack.shared.logger
def set_env_variables(gpu_id):
# Set sharing strategy to file_system to avoid file descriptor leaks
torch.multiprocessing.set_sharing_strategy("file_system")
# Only use one GPU, if we have one.
# For Tensorflow
# TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
@@ -27,6 +28,19 @@ def set_env_variables(gpu_id):
# For PyTorch
torch.cuda.set_device(gpu_id)
# Fix TensorFlow GPU memory growth
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
gpu = gpus[gpu_id]
tf.config.experimental.set_visible_devices(gpu, "GPU")
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
print(e)
# Disable tensorflow logs, except in the case of an error.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"