mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
170 lines
5.3 KiB
Python
170 lines
5.3 KiB
Python
"""
|
|
|
|
TextAttack Command Class for Attack Single Threaded
|
|
-----------------------------------------------------
|
|
|
|
A command line parser to run an attack in single thread from user specifications.
|
|
|
|
"""
|
|
|
|
from collections import deque
|
|
import os
|
|
import time
|
|
|
|
import tqdm
|
|
|
|
import textattack
|
|
|
|
from .attack_args_helpers import (
|
|
parse_attack_from_args,
|
|
parse_dataset_from_args,
|
|
parse_logger_from_args,
|
|
)
|
|
|
|
logger = textattack.shared.logger
|
|
|
|
|
|
def run(args, checkpoint=None):
|
|
# Only use one GPU, if we have one.
|
|
# TODO: Running Universal Sentence Encoder uses multiple GPUs
|
|
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
# Disable tensorflow logs, except in the case of an error.
|
|
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
|
try:
|
|
# Fix TensorFlow GPU memory growth
|
|
import tensorflow as tf
|
|
|
|
gpus = tf.config.experimental.list_physical_devices("GPU")
|
|
if gpus:
|
|
try:
|
|
# Currently, memory growth needs to be the same across GPUs
|
|
for gpu in gpus:
|
|
tf.config.experimental.set_memory_growth(gpu, True)
|
|
except RuntimeError as e:
|
|
# Memory growth must be set before GPUs have been initialized
|
|
print(e)
|
|
except ModuleNotFoundError:
|
|
pass
|
|
|
|
if args.checkpoint_resume:
|
|
num_remaining_attacks = checkpoint.num_remaining_attacks
|
|
worklist = checkpoint.worklist
|
|
worklist_tail = checkpoint.worklist_tail
|
|
|
|
logger.info(
|
|
"Recovered from checkpoint previously saved at {}".format(
|
|
checkpoint.datetime
|
|
)
|
|
)
|
|
print(checkpoint, "\n")
|
|
else:
|
|
num_remaining_attacks = args.num_examples
|
|
worklist = deque(range(0, args.num_examples))
|
|
worklist_tail = worklist[-1]
|
|
|
|
start_time = time.time()
|
|
|
|
# Attack
|
|
attack = parse_attack_from_args(args)
|
|
print(attack, "\n")
|
|
|
|
# Logger
|
|
if args.checkpoint_resume:
|
|
attack_log_manager = checkpoint.log_manager
|
|
else:
|
|
attack_log_manager = parse_logger_from_args(args)
|
|
|
|
load_time = time.time()
|
|
textattack.shared.logger.info(f"Load time: {load_time - start_time}s")
|
|
|
|
if args.interactive:
|
|
print("Running in interactive mode")
|
|
print("----------------------------")
|
|
|
|
while True:
|
|
print('Enter a sentence to attack or "q" to quit:')
|
|
text = input()
|
|
|
|
if text == "q":
|
|
break
|
|
|
|
if not text:
|
|
continue
|
|
|
|
print("Attacking...")
|
|
|
|
attacked_text = textattack.shared.attacked_text.AttackedText(text)
|
|
initial_result = attack.goal_function.get_output(attacked_text)
|
|
result = next(attack.attack_dataset([(text, initial_result)]))
|
|
print(result.__str__(color_method="ansi") + "\n")
|
|
|
|
else:
|
|
# Not interactive? Use default dataset.
|
|
dataset = parse_dataset_from_args(args)
|
|
|
|
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
|
|
if args.checkpoint_resume:
|
|
num_results = checkpoint.results_count
|
|
num_failures = checkpoint.num_failed_attacks
|
|
num_successes = checkpoint.num_successful_attacks
|
|
else:
|
|
num_results = 0
|
|
num_failures = 0
|
|
num_successes = 0
|
|
|
|
for result in attack.attack_dataset(dataset, indices=worklist):
|
|
attack_log_manager.log_result(result)
|
|
|
|
if not args.disable_stdout:
|
|
print("\n")
|
|
if (not args.attack_n) or (
|
|
not isinstance(result, textattack.attack_results.SkippedAttackResult)
|
|
):
|
|
pbar.update(1)
|
|
else:
|
|
# worklist_tail keeps track of highest idx that has been part of worklist
|
|
# Used to get the next dataset element when attacking with `attack_n` = True.
|
|
worklist_tail += 1
|
|
worklist.append(worklist_tail)
|
|
|
|
num_results += 1
|
|
|
|
if (
|
|
type(result) == textattack.attack_results.SuccessfulAttackResult
|
|
or type(result) == textattack.attack_results.MaximizedAttackResult
|
|
):
|
|
num_successes += 1
|
|
if type(result) == textattack.attack_results.FailedAttackResult:
|
|
num_failures += 1
|
|
pbar.set_description(
|
|
"[Succeeded / Failed / Total] {} / {} / {}".format(
|
|
num_successes, num_failures, num_results
|
|
)
|
|
)
|
|
|
|
if (
|
|
args.checkpoint_interval
|
|
and len(attack_log_manager.results) % args.checkpoint_interval == 0
|
|
):
|
|
new_checkpoint = textattack.shared.Checkpoint(
|
|
args, attack_log_manager, worklist, worklist_tail
|
|
)
|
|
new_checkpoint.save()
|
|
attack_log_manager.flush()
|
|
|
|
pbar.close()
|
|
print()
|
|
# Enable summary stdout
|
|
if args.disable_stdout:
|
|
attack_log_manager.enable_stdout()
|
|
attack_log_manager.log_summary()
|
|
attack_log_manager.flush()
|
|
print()
|
|
# finish_time = time.time()
|
|
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
|
|
|
|
return attack_log_manager.results
|