1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/commands/attack/run_attack_single_threaded.py
2020-11-01 00:58:16 -04:00

170 lines
5.3 KiB
Python

"""
TextAttack Command Class for Attack Single Threaded
-----------------------------------------------------
A command-line entry point that runs an attack in a single thread according to user-specified arguments.
"""
from collections import deque
import os
import time
import tqdm
import textattack
from .attack_args_helpers import (
parse_attack_from_args,
parse_dataset_from_args,
parse_logger_from_args,
)
# TextAttack's shared logger, aliased at module level for the status messages below.
from textattack.shared import logger
def _enable_tensorflow_memory_growth():
    """Best-effort: ask TensorFlow to grow GPU memory on demand.

    TensorFlow is optional, so nothing happens when it is not installed. If the
    GPUs were already initialized, TensorFlow raises ``RuntimeError`` and the
    setting can no longer be changed; that error is logged rather than raised.
    """
    try:
        import tensorflow as tf
    except ModuleNotFoundError:
        # TensorFlow is an optional dependency; nothing to configure.
        return
    gpus = tf.config.experimental.list_physical_devices("GPU")
    try:
        # Currently, memory growth needs to be the same across GPUs.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized.
        logger.warning(f"Could not enable TensorFlow GPU memory growth: {e}")


def _attack_interactively(attack):
    """Read sentences from stdin, attack each one, and print the colored result.

    Typing ``q`` (or closing stdin) ends the session; empty lines are ignored.
    The goal function's output on the unmodified sentence is passed to the
    attack as that sentence's ground-truth output.
    """
    print("Running in interactive mode")
    print("----------------------------")
    while True:
        print('Enter a sentence to attack or "q" to quit:')
        try:
            text = input()
        except EOFError:
            # stdin was closed (e.g. Ctrl-D or piped input ran out): same as "q".
            break
        if text == "q":
            break
        if not text:
            continue
        print("Attacking...")
        attacked_text = textattack.shared.attacked_text.AttackedText(text)
        initial_result = attack.goal_function.get_output(attacked_text)
        result = next(attack.attack_dataset([(text, initial_result)]))
        print(result.__str__(color_method="ansi") + "\n")


def run(args, checkpoint=None):
    """Run an attack in the current process using a single thread.

    Args:
        args: Parsed command-line arguments. This function reads
            ``checkpoint_resume``, ``num_examples``, ``interactive``,
            ``disable_stdout``, ``attack_n`` and ``checkpoint_interval``; the
            ``parse_*_from_args`` helpers read further fields.
        checkpoint: Checkpoint to resume from. Required when
            ``args.checkpoint_resume`` is set; otherwise ignored.

    Returns:
        ``attack_log_manager.results`` after attacking the dataset, or ``None``
        in interactive mode.

    Raises:
        ValueError: If ``args.checkpoint_resume`` is set but ``checkpoint`` is None.
    """
    # Only use one GPU, if we have one.
    # TODO: Running Universal Sentence Encoder uses multiple GPUs
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    # Disable tensorflow logs, except in the case of an error.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
    _enable_tensorflow_memory_growth()

    if args.checkpoint_resume:
        if checkpoint is None:
            raise ValueError(
                "A checkpoint must be provided when `checkpoint_resume` is set."
            )
        num_remaining_attacks = checkpoint.num_remaining_attacks
        worklist = checkpoint.worklist
        worklist_tail = checkpoint.worklist_tail
        logger.info(
            "Recovered from checkpoint previously saved at {}".format(
                checkpoint.datetime
            )
        )
        print(checkpoint, "\n")
    else:
        num_remaining_attacks = args.num_examples
        worklist = deque(range(0, args.num_examples))
        # Highest dataset index queued so far. Use -1 for an empty worklist so the
        # first index appended later is 0 (indexing an empty deque would raise).
        worklist_tail = worklist[-1] if worklist else -1

    start_time = time.time()

    # Attack
    attack = parse_attack_from_args(args)
    print(attack, "\n")

    # Logger
    if args.checkpoint_resume:
        attack_log_manager = checkpoint.log_manager
    else:
        attack_log_manager = parse_logger_from_args(args)

    load_time = time.time()
    logger.info(f"Load time: {load_time - start_time}s")

    if args.interactive:
        _attack_interactively(attack)
        return None

    # Not interactive? Use default dataset.
    dataset = parse_dataset_from_args(args)

    if args.checkpoint_resume:
        num_results = checkpoint.results_count
        num_failures = checkpoint.num_failed_attacks
        num_successes = checkpoint.num_successful_attacks
    else:
        num_results = num_failures = num_successes = 0

    results_module = textattack.attack_results
    # `with` guarantees the progress bar is closed even if an attack raises.
    with tqdm.tqdm(total=num_remaining_attacks, smoothing=0) as pbar:
        for result in attack.attack_dataset(dataset, indices=worklist):
            attack_log_manager.log_result(result)

            if not args.disable_stdout:
                print("\n")

            skipped = isinstance(result, results_module.SkippedAttackResult)
            if args.attack_n and skipped:
                # With `attack_n`, a skipped example does not count toward the
                # target: queue the next never-seen dataset index in its place.
                # worklist_tail tracks the highest index ever put on the worklist.
                worklist_tail += 1
                worklist.append(worklist_tail)
            else:
                pbar.update(1)

            num_results += 1
            # Exact type checks are intentional: result classes may subclass one
            # another (e.g. skipped vs. failed) — verify before using isinstance.
            result_type = type(result)
            if result_type in (
                results_module.SuccessfulAttackResult,
                results_module.MaximizedAttackResult,
            ):
                num_successes += 1
            if result_type == results_module.FailedAttackResult:
                num_failures += 1
            pbar.set_description(
                "[Succeeded / Failed / Total] {} / {} / {}".format(
                    num_successes, num_failures, num_results
                )
            )

            if (
                args.checkpoint_interval
                and len(attack_log_manager.results) % args.checkpoint_interval == 0
            ):
                new_checkpoint = textattack.shared.Checkpoint(
                    args, attack_log_manager, worklist, worklist_tail
                )
                new_checkpoint.save()
                attack_log_manager.flush()

    print()
    # Enable summary stdout
    if args.disable_stdout:
        attack_log_manager.enable_stdout()
    attack_log_manager.log_summary()
    attack_log_manager.flush()
    print()
    logger.info(f"Attack time: {time.time() - load_time}s")
    return attack_log_manager.results