mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00
parallel fix; add --check-robustness training option
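The `--check-robustness` half of the commit title is not visible in the hunks below, which show only the parallel fix. Purely for orientation, a boolean training option like this is usually registered through argparse along these lines; the parser wiring and help text are assumptions for illustration, not taken from this commit:

    import argparse

    # Hypothetical sketch: how a --check-robustness style flag is typically
    # declared. TextAttack's actual training parser is not shown in this diff.
    parser = argparse.ArgumentParser(description="training options (sketch)")
    parser.add_argument(
        "--check-robustness",
        action="store_true",
        default=False,
        help="periodically attack the model during training to gauge robustness",
    )

    args = parser.parse_args(["--check-robustness"])
    print(args.check_robustness)  # True; argparse exposes it as args.check_robustness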
@@ -20,6 +20,7 @@ logger = textattack.shared.logger

def set_env_variables(gpu_id):
    # Set sharing strategy to file_system to avoid file descriptor leaks
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Only use one GPU, if we have one.
    # For Tensorflow
    # TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
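The sharing-strategy call above addresses a common failure mode of parallel attacks: with PyTorch's default file-descriptor sharing strategy, every tensor passed between worker processes keeps file descriptors open, and a long run can exhaust the per-process descriptor limit ("too many open files"). Below is a minimal standalone sketch of the same fix; the worker function and queue are illustrative, not taken from the repository:

    import torch
    import torch.multiprocessing as mp

    def worker(rank, queue):
        # The strategy is per-process state, so each child re-applies it.
        mp.set_sharing_strategy("file_system")
        queue.put(torch.zeros(2) + rank)

    if __name__ == "__main__":
        # "file_system" backs shared tensors with named files in shared
        # memory rather than holding one open descriptor per shared tensor.
        mp.set_sharing_strategy("file_system")
        print(mp.get_sharing_strategy())  # -> file_system
        queue = mp.Queue()
        procs = [mp.Process(target=worker, args=(r, queue)) for r in range(2)]
        for p in procs:
            p.start()
        print([queue.get() for _ in procs])
        for p in procs:
            p.join()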
@@ -27,6 +28,19 @@ def set_env_variables(gpu_id):

    # For PyTorch
    torch.cuda.set_device(gpu_id)

    # Fix TensorFlow GPU memory growth
    import tensorflow as tf

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            gpu = gpus[gpu_id]
            tf.config.experimental.set_visible_devices(gpu, "GPU")
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)

    # Disable tensorflow logs, except in the case of an error.
    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
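Taken together, the TensorFlow block added in this hunk does three things per worker: it pins the process to a single visible GPU, enables memory growth so TensorFlow allocates device memory on demand instead of reserving the whole card at startup, and silences TensorFlow's C++ logging. A self-contained sketch of the same setup, assuming TensorFlow 2.x (note that TF_CPP_MIN_LOG_LEVEL only takes effect if it is set before tensorflow is imported):

    import os

    # "3" is the quietest C++ log level ("0" shows everything); it must be
    # exported before the tensorflow import below to have any effect.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    import tensorflow as tf

    def configure_gpu(gpu_id):
        gpus = tf.config.experimental.list_physical_devices("GPU")
        if gpus:
            try:
                gpu = gpus[gpu_id]
                # Hide the other GPUs from this process ...
                tf.config.experimental.set_visible_devices(gpu, "GPU")
                # ... and grow allocations on demand rather than up front.
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # Visible devices and memory growth cannot be changed once
                # the TensorFlow runtime has been initialized.
                print(e)

    configure_gpu(0)

In newer TensorFlow releases the first two calls are also available without the experimental prefix, as tf.config.list_physical_devices and tf.config.set_visible_devices.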