mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

make trainer extendable

This commit is contained in:
Jin Yong Yoo
2021-05-01 04:42:01 -04:00
parent 72409a0dd9
commit 1d8b72b85d
13 changed files with 821 additions and 600 deletions

View File

@@ -28,6 +28,7 @@ Complete API Reference
:members:
:undoc-members:
:show-inheritance:
:exclude-members: CommandLineAttackArgs
.. automodule:: textattack.attack
:members:
@@ -48,18 +49,9 @@ Complete API Reference
:members:
:undoc-members:
:show-inheritance:
:exclude-members: CommandLineTrainingArgs
.. automodule:: textattack.trainer
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.dataset_args
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.model_args
:members:
:undoc-members:
:show-inheritance:

View File

@@ -1,6 +1,6 @@
bert-score>=0.3.5
editdistance
flair==0.6.1.post1
flair
filelock
language_tool_python
lemminflect

View File

@@ -30,24 +30,22 @@ goldmember is funny enough to reasoned the embarrassment of bringing a
--------------------------------------------- Result 3 ---------------------------------------------
Positive (100%) --> Negative (60%)
Positive (100%) --> [FAILED]
it may not be particularly innovative , but the film's crisp , unaffected style and air of gentle longing make it unexpectedly rewarding .
it probable not be particularly creative , but the film's brusque , undamaged shape and midair of mild desiring doing it surprisingly beneficial .
it may not be particularly innovative , but the film's crisp , unaffected style and air of gentle longing make it unexpectedly rewarding .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 3 |
| Number of failed attacks: | 0 |
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 23.71% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 9.38% |
| Average num. words per input: | 15.0 |
| Avg num queries: | 71.0 |
+-------------------------------+--------+

View File

@@ -9,6 +9,7 @@ Attack: TextAttack builds attacks from four components:
The ``Attack`` class represents an adversarial attack composed of a goal function, search method, transformation, and constraints.
"""
from collections import OrderedDict
import lru
import torch
@@ -358,8 +359,11 @@ class Attack:
AttackResult
"""
assert isinstance(
example, AttackedText
), "`example` must be of type `AttackedText`."
example, (str, OrderedDict, AttackedText)
), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
if isinstance(example, (str, OrderedDict)):
example = AttackedText(example)
assert isinstance(
ground_truth_output, (int, str)
), "`ground_truth_output` must either be `str` or `int`."

View File

@@ -111,43 +111,73 @@ GOAL_FUNCTION_CLASS_NAMES = {
@dataclass
class AttackArgs:
"""Attack args for running attacks via API. This assumes that ``Attack``
has already been created by the user.
"""Attack arguments for running attacks via API. This assumes that
``Attack`` has already been created by the user.
Args:
num_examples (int): The number of examples to attack. -1 for entire dataset.
num_examples_offset (int): The offset to start at in the dataset.
query_budget (int): The maximum number of model queries allowed per example attacked.
This is optional and setting this overwrites the query budget set in `GoalFunction` object.
shuffle (bool): If `True`, shuffle the samples before we attack the dataset. Note this does not involve shuffling the dataset internally. Default is False.
attack_n (bool): Whether to run attack until total of `n` examples have been attacked (not skipped).
checkpoint_dir (str): The directory to save checkpoint files.
checkpoint_interval (int): If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.
random_seed (int): Random seed for reproducibility. Default is 765.
parallel (bool): Run attack using multiple CPUs/GPUs.
num_workers_per_device (int): Number of worker processes to run per device. For example, if you are using GPUs and ``num_workers_per_device=2``,
then 2 processes will be running in each GPU. If you are only using CPU, then this is equivalent to running 2 processes concurrently.
log_to_txt (str): Path to which to save attack logs as a text file. Set this argument if you want to save text logs.
If the last part of the path ends with `.txt` extension, the path is assumed to be the path for the output file.
log_to_csv (str): Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs.
If the last part of the path ends with `.csv` extension, the path is assumed to be the path for the output file.
csv_coloring_style (str): Method for choosing how to mark perturbed parts of the text. Options are "file" and "plain".
"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".
log_to_visdom (dict): Set this argument if you want to log attacks to Visdom. The dictionary should have the following
three keys and their corresponding values: `"env", "port", "hostname"` (e.g. `{"env": "main", "port": 8097, "hostname": "localhost"}`).
log_to_wandb (str): Name of the wandb project. Set this argument if you want to log attacks to Wandb.
disable_stdout (bool): Disable logging attack results to stdout.
silent (bool): Disable all logging.
ignore_exceptions (bool): Skip examples that raise an error instead of exiting.
num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
The number of examples to attack. -1 for entire dataset.
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
The number of successful adversarial examples we want. This is different from `num_examples`
as `num_examples` only cares about attacking `N` samples while `num_successful_examples` aims to keep attacking
until we have `N` successful cases.
.. note::
If set, this argument overrides the `num_examples` argument.
num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
The offset index to start at in the dataset.
attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run the attack until a total of `N` examples have been attacked (not skipped).
shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, we randomly shuffle the examples before attacking. Note that this does not shuffle
the dataset itself; instead, we shuffle the list of indices of the examples we want to attack. This means
`shuffle` can be used together with checkpoint saving.
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
The maximum number of model queries allowed per example attacked.
If not set, we use the query budget set in the `GoalFunction` object (which by default is `float("inf")`).
.. note::
Setting this overwrites the query budget set in `GoalFunction` object.
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
If set, checkpoint will be saved after attacking every `N` examples. If `None` is passed, no checkpoints will be saved.
checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
The directory to save checkpoint files.
random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
Random seed for reproducibility.
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, run attack using multiple CPUs/GPUs.
num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
Number of worker processes to run per device in parallel mode (i.e. `parallel=True`). For example, if you are using GPUs and `num_workers_per_device=2`,
then 2 processes will be running in each GPU.
log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
If set, save attack logs as a `.txt` file to the directory specified by this argument.
If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
If set, save attack logs as a CSV file to the directory specified by this argument.
If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
Method for choosing how to mark perturbed parts of the text. Options are `"file"` and `"plain"`.
`"file"` wraps perturbed parts with double brackets `[[ <text> ]]` while `"plain"` does not mark the text in any way.
log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
If set, the Visdom logger is used, with the provided dictionary passed as keyword arguments to `textattack.loggers.VisdomLogger`.
Pass in an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following
three keys and their corresponding values: `"env"`, `"port"`, `"hostname"`.
log_to_wandb (:obj:`str`, `optional`, defaults to :obj:`None`):
If set, log the attack results and summary to Wandb project specified by this argument.
disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
Disable displaying individual attack results to stdout.
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
Disable all logging (except for errors). This is stronger than `disable_stdout`.
"""
num_examples: int = 5
num_examples: int = 10
num_successful_examples: int = None
num_examples_offset: int = 0
query_budget: int = None
shuffle: bool = False
attack_n: bool = False
checkpoint_dir: str = "checkpoints"
shuffle: bool = False
query_budget: int = None
checkpoint_interval: int = None
checkpoint_dir: str = "checkpoints"
random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
parallel: bool = False
num_workers_per_device: int = 1
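Taken together, the new arguments make an API-driven run look roughly like the sketch below (a sketch, assuming the top-level `AttackArgs`/`Attacker` exports; the model checkpoint, dataset, and argument values are illustrative):

import transformers
import textattack

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-rotten-tomatoes"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-rotten-tomatoes"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
dataset = textattack.datasets.HuggingFaceDataset("rotten_tomatoes", split="test")

attack_args = textattack.AttackArgs(
    num_successful_examples=20,  # keep attacking until 20 successes (overrides num_examples)
    shuffle=True,                # shuffles the index list, so checkpointing still works
    query_budget=500,            # overwrites the GoalFunction's own budget
    checkpoint_interval=5,
    log_to_csv="attack_log.csv",
    csv_coloring_style="plain",
)
attacker = textattack.Attacker(attack, dataset, attack_args=attack_args)
results = attacker.attack_dataset()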
@@ -158,106 +188,138 @@ class AttackArgs:
log_to_wandb: str = None
disable_stdout: bool = False
silent: bool = False
ignore_exceptions: bool = False
def __post_init__(self):
if self.num_successful_examples:
self.num_examples = None
if self.num_examples:
assert (
self.num_examples > 0 or self.num_examples == -1
), "`num_examples` must be greater than 0 or equal to -1."
if self.num_successful_examples:
assert (
self.num_successful_examples > 0
), "`num_examples` must be greater than 0."
if self.query_budget:
assert self.query_budget > 0, "`query_budget` must be greater than 0"
if self.checkpoint_interval:
assert (
self.checkpoint_interval > 0
), "`checkpoint_interval` must be greater than 0"
assert (
self.num_workers_per_device > 0
), "`num_workers_per_device` must be greater than 0"
assert self.csv_coloring_style in {
"file",
"plain",
}, '`csv_coloring_style` must either be "file" or "plain".'
@classmethod
def add_parser_args(cls, parser):
"""Add listed args to command line parser."""
parser.add_argument(
default_obj = cls()
num_ex_group = parser.add_mutually_exclusive_group(required=False)
num_ex_group.add_argument(
"--num-examples",
"-n",
type=int,
required=False,
default=5,
help="The number of examples to process, -1 for entire dataset",
default=default_obj.num_examples,
help="The number of examples to process, -1 for entire dataset.",
)
num_ex_group.add_argument(
"--num-successful-examples",
type=int,
default=default_obj.num_successful_examples,
help="The number of successful adversarial examples we want.",
)
parser.add_argument(
"--num-examples-offset",
"-o",
type=int,
required=False,
default=0,
default=default_obj.num_examples_offset,
help="The offset to start at in the dataset.",
)
parser.add_argument(
"--query-budget",
"-q",
type=int,
default=None,
default=default_obj.query_budget,
help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
)
parser.add_argument(
"--shuffle",
action="store_true",
default=False,
default=default_obj.shuffle,
help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
)
parser.add_argument(
"--attack-n",
action="store_true",
default=False,
default=default_obj.attack_n,
help="Whether to run attack until `n` examples have been attacked (not skipped).",
)
parser.add_argument(
"--checkpoint-dir",
required=False,
type=str,
default="checkpoints",
default=default_obj.checkpoint_dir,
help="The directory to save checkpoint files.",
)
parser.add_argument(
"--checkpoint-interval",
required=False,
type=int,
default=default_obj.checkpoint_interval,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
def str_to_int(s):
return sum((ord(c) for c in s))
parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"), type=int)
parser.add_argument(
"--random-seed",
default=default_obj.random_seed,
type=int,
help="Random seed for reproducibility.",
)
parser.add_argument(
"--parallel",
action="store_true",
default=False,
default=default_obj.parallel,
help="Run attack using multiple GPUs.",
)
parser.add_argument(
"--num-workers-per-device",
default=1,
default=default_obj.num_workers_per_device,
type=int,
help="Number of worker processes to run per device.",
)
parser.add_argument(
"--log-to-txt",
nargs="?",
default=None,
default=default_obj.log_to_txt,
const="",
type=str,
help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
"If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
)
parser.add_argument(
"--log-to-csv",
nargs="?",
default=None,
default=default_obj.log_to_csv,
const="",
type=str,
help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
"If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
)
parser.add_argument(
"--csv-coloring-style",
default="file",
default=default_obj.csv_coloring_style,
type=str,
help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
'"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
)
parser.add_argument(
"--log-to-visdom",
nargs="?",
@@ -268,30 +330,25 @@ class AttackArgs:
'three keys and their corresponding values: `"env", "port", "hostname"`. '
'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
)
parser.add_argument(
"--log-to-wandb",
nargs="?",
default=None,
default=default_obj.log_to_wandb,
const="textattack",
type=str,
help="Name of the wandb project. Set this argument if you want to log attacks to Wandb.",
)
parser.add_argument(
"--disable-stdout",
action="store_true",
default=default_obj.disable_stdout,
help="Disable logging attack results to stdout",
)
parser.add_argument(
"--silent", action="store_true", default=False, help="Disable all logging"
)
parser.add_argument(
"--ignore-exceptions",
"--silent",
action="store_true",
default=False,
help="Skip examples that raise an error instead of exiting.",
default=default_obj.silent,
help="Disable all logging",
)
return parser
@@ -356,23 +413,36 @@ class AttackArgs:
@dataclass
class _CommandLineAttackArgs:
"""Command line interface attack args. This requires more arguments to
"""Attack args for command line execution. This requires more arguments to
create ``Attack`` object as specified.
Args:
transformation (str): Name of transformation to use.
constraints (list[str]): List of names of constraints to use.
goal_function (str): Name of goal function to use.
search_method (str): Name of search method to use.
attack_recipe (str): Name of attack recipe to use.
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
attack_from_file (str): Path of `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
interactive (bool): If `True`, run the attack in interactive mode. Default is `False`.
parallel (bool): If `True`, attack in parallel. Default is `False`.
model_batch_size (int): The batch size for making calls to the model.
model_cache_size (int): The maximum number of items to keep in the model results cache at once.
constraint_cache_size (int): The maximum number of items to keep in the constraints cache at once.
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
Name of transformation to use.
constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
List of names of constraints to use.
goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
Name of goal function to use.
search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
Name of search method to use.
attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
Name of attack recipe to use.
.. note::
Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
Path of `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.
.. note::
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, run the attack in interactive mode.
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, attack in parallel.
model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
The batch size for making queries to the victim model.
model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
The maximum number of items to keep in the model results cache at once.
constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
The maximum number of items to keep in the constraints cache at once.
"""
transformation: str = "word-swap-embedding"
@@ -390,6 +460,7 @@ class _CommandLineAttackArgs:
@classmethod
def add_parser_args(cls, parser):
"""Add listed args to command line parser."""
default_obj = cls()
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
)
@@ -397,7 +468,7 @@ class _CommandLineAttackArgs:
"--transformation",
type=str,
required=False,
default="word-swap-embedding",
default=default_obj.transformation,
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(transformation_names),
)
@@ -406,7 +477,7 @@ class _CommandLineAttackArgs:
type=str,
required=False,
nargs="*",
default=["repeat", "stopword"],
default=default_obj.constraints,
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(CONSTRAINT_CLASS_NAMES.keys()),
)
@@ -414,7 +485,7 @@ class _CommandLineAttackArgs:
parser.add_argument(
"--goal-function",
"-g",
default="untargeted-classification",
default=default_obj.goal_function,
help=f"The goal function to use. choices: {goal_function_choices}",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
@@ -425,7 +496,7 @@ class _CommandLineAttackArgs:
"-s",
type=str,
required=False,
default="greedy-word-wir",
default=default_obj.search_method,
help=f"The search method to use. choices: {search_choices}",
)
attack_group.add_argument(
@@ -434,7 +505,7 @@ class _CommandLineAttackArgs:
"-r",
type=str,
required=False,
default=None,
default=default_obj.attack_recipe,
help="full attack recipe (overrides provided goal function, transformation & constraints)",
choices=ATTACK_RECIPE_NAMES.keys(),
)
@@ -442,31 +513,31 @@ class _CommandLineAttackArgs:
"--attack-from-file",
type=str,
required=False,
default=None,
default=default_obj.attack_from_file,
help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
)
parser.add_argument(
"--interactive",
action="store_true",
default=False,
default=default_obj.interactive,
help="Whether to run attacks interactively.",
)
parser.add_argument(
"--model-batch-size",
type=int,
default=32,
default=default_obj.model_batch_size,
help="The batch size for making calls to the model.",
)
parser.add_argument(
"--model-cache-size",
type=int,
default=2 ** 18,
default=default_obj.model_cache_size,
help="The maximum number of items to keep in the model results cache at once.",
)
parser.add_argument(
"--constraint-cache-size",
type=int,
default=2 ** 18,
default=default_obj.constraint_cache_size,
help="The maximum number of items to keep in the constraints cache at once.",
)

View File

@@ -4,6 +4,7 @@ import multiprocessing as mp
import os
import queue
import random
import traceback
import torch
import tqdm
@@ -126,13 +127,13 @@ class Attacker:
else:
self._attack()
# Turn logger back on since other modules can be using Attacker (e.g. Trainer)
if self.attack_args.silent:
logger.setLevel(logging.INFO)
return self.attack_log_manager.results
def _get_worklist(self, start, end, num_examples, shuffle):
if end - start > num_examples:
if end - start < num_examples:
logger.warn(
f"Attempting to attack {num_examples} samples when only {end-start} are available."
)
@@ -141,6 +142,7 @@ class Attacker:
random.shuffle(candidates)
worklist = collections.deque(candidates[:num_examples])
candidates = collections.deque(candidates[num_examples:])
assert (len(worklist) + len(candidates)) == (end - start)
return worklist, candidates
def _attack(self):
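Conceptually, `_get_worklist` splits the index range into an active deque and a reserve of candidates that get promoted when examples are skipped or attacks fail; a standalone sketch of the same bookkeeping:

import collections
import random

def make_worklist(start, end, num_examples, shuffle):
    # Mirrors `_get_worklist`: the first `num_examples` indices become the
    # active worklist; the rest are held back as promotable candidates.
    candidates = list(range(start, end))
    if shuffle:
        random.shuffle(candidates)
    worklist = collections.deque(candidates[:num_examples])
    candidates = collections.deque(candidates[num_examples:])
    assert len(worklist) + len(candidates) == end - start
    return worklist, candidates

worklist, candidates = make_worklist(0, 10, 3, shuffle=False)
print(list(worklist))  # [0, 1, 2]
worklist.append(candidates.popleft())  # e.g. after a skipped example
print(list(worklist))  # [0, 1, 2, 3]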
@@ -157,15 +159,26 @@ class Attacker:
f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
)
else:
num_remaining_attacks = self.attack_args.num_examples
# We make `worklist` deque (linked-list) for easy pop and append.
# Candidates are other samples we can attack if we need more samples.
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_examples,
self.attack_args.shuffle,
)
if self.attack_args.num_successful_examples:
num_remaining_attacks = self.attack_args.num_successful_examples
# We make `worklist` deque (linked-list) for easy pop and append.
# Candidates are other samples we can attack if we need more samples.
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_successful_examples,
self.attack_args.shuffle,
)
else:
num_remaining_attacks = self.attack_args.num_examples
# We make `worklist` deque (linked-list) for easy pop and append.
# Candidates are other samples we can attack if we need more samples.
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_examples,
self.attack_args.shuffle,
)
if not self.attack_args.silent:
print(self.attack, "\n")
@@ -174,19 +187,20 @@ class Attacker:
if self._checkpoint:
num_results = self._checkpoint.results_count
num_failures = self._checkpoint.num_failed_attacks
num_skipped = self._checkpoint.num_skipped_attacks
num_successes = self._checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_skipped = 0
num_successes = 0
if hasattr(self.attack.goal_function.model, "to"):
self.attack.goal_function.model.to(textattack.shared.utils.device)
i = 0
sample_exhaustion_warned = False
while worklist:
idx = worklist.popleft()
i += 1
try:
example, ground_truth_output = self.dataset[idx]
except IndexError:
@@ -197,32 +211,36 @@ class Attacker:
try:
result = self.attack.attack(example, ground_truth_output)
except Exception as e:
if self.attack_args.ignore_exceptions:
continue
else:
raise e
self.attack_log_manager.log_result(result)
if not self.attack_args.disable_stdout:
print("\n")
if isinstance(result, SkippedAttackResult) and self.attack_args.attack_n:
raise e
if (
isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
) or (
not isinstance(result, SuccessfulAttackResult)
and self.attack_args.num_successful_examples
):
if worklist_candidates:
next_sample = worklist_candidates.popleft()
worklist.append(next_sample)
else:
logger.warn("`attack_n=True` but no more samples to attack.")
if not sample_exhaustion_warned:
logger.warn("Ran out of samples to attack!")
sample_exhaustion_warned = True
else:
pbar.update(1)
self.attack_log_manager.log_result(result)
if not self.attack_args.disable_stdout and not self.attack_args.silent:
print("\n")
num_results += 1
if isinstance(result, SkippedAttackResult):
num_skipped += 1
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
num_successes += 1
if isinstance(result, FailedAttackResult):
num_failures += 1
pbar.set_description(
f"[Succeeded / Failed / Total] {num_successes} / {num_failures} / {num_results}"
f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
)
if (
@@ -260,13 +278,26 @@ class Attacker:
f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
)
else:
num_remaining_attacks = self.attack_args.num_examples
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_examples,
self.attack_args.shuffle,
)
if self.attack_args.num_successful_examples:
num_remaining_attacks = self.attack_args.num_successful_examples
# We make `worklist` deque (linked-list) for easy pop and append.
# Candidates are other samples we can attack if we need more samples.
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_successful_examples,
self.attack_args.shuffle,
)
else:
num_remaining_attacks = self.attack_args.num_examples
# We make `worklist` deque (linked-list) for easy pop and append.
# Candidates are other samples we can attack if we need more samples.
worklist, worklist_candidates = self._get_worklist(
self.attack_args.num_examples_offset,
len(self.dataset),
self.attack_args.num_examples,
self.attack_args.shuffle,
)
in_queue = torch.multiprocessing.Queue()
out_queue = torch.multiprocessing.Queue()
@@ -313,25 +344,38 @@ class Attacker:
if self._checkpoint:
num_results = self._checkpoint.results_count
num_failures = self._checkpoint.num_failed_attacks
num_skipped = self._checkpoint.num_skipped_attacks
num_successes = self._checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_skipped = 0
num_successes = 0
logger.info(f"Worklist size: {len(worklist)}")
logger.info(f"Worklist candidate size: {len(worklist_candidates)}")
sample_exhaustion_warned = False
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
while worklist:
idx, result = out_queue.get(block=True)
worklist.remove(idx)
if isinstance(result, Exception):
if self.attack_args.ignore_exceptions:
continue
else:
worker_pool.terminate()
worker_pool.join()
raise result
self.attack_log_manager.log_result(result)
if self.attack_args.attack_n and isinstance(result, SkippedAttackResult):
if isinstance(result, tuple) and isinstance(result[0], Exception):
logger.error(
f'Exception encountered for input "{self.dataset[idx][0]}".'
)
error_trace = result[1]
logger.error(error_trace)
worker_pool.terminate()
worker_pool.join()
exit(1)
elif (
isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
) or (
not isinstance(result, SuccessfulAttackResult)
and self.attack_args.num_successful_examples
):
if worklist_candidates:
next_sample = worklist_candidates.popleft()
example, ground_truth_output = self.dataset[next_sample]
@@ -341,21 +385,24 @@ class Attacker:
worklist.append(next_sample)
in_queue.put((next_sample, example, ground_truth_output))
else:
logger.warn(
f"Attempted to attack {self.attack_args.num_examples} examples with but ran out of examples. "
f"You might see fewer number of results than {self.attack_args.num_examples}."
)
if not sample_exhaustion_warned:
logger.warn("Ran out of samples to attack!")
sample_exhaustion_warned = True
else:
pbar.update()
num_results += 1
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
num_successes += 1
if isinstance(result, FailedAttackResult):
num_failures += 1
pbar.set_description(
f"[Succeeded / Failed / Total] {num_successes} / {num_failures} / {num_results}"
)
self.attack_log_manager.log_result(result)
num_results += 1
if isinstance(result, SkippedAttackResult):
num_skipped += 1
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
num_successes += 1
if isinstance(result, FailedAttackResult):
num_failures += 1
pbar.set_description(
f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
)
if (
self.attack_args.checkpoint_interval
@@ -372,7 +419,10 @@ class Attacker:
new_checkpoint.save()
self.attack_log_manager.flush()
worker_pool.terminate()
# Send sentinel values to worker processes
for _ in range(num_workers):
in_queue.put(("END", "END", "END"))
worker_pool.close()
worker_pool.join()
pbar.close()
@@ -504,11 +554,14 @@ def attack_from_queue(
while True:
try:
i, example, ground_truth_output = in_queue.get(timeout=5)
result = attack.attack(example, ground_truth_output)
out_queue.put((i, result))
if i == "END" and example == "END" and ground_truth_output == "END":
# End process when sentinel value is received
break
else:
result = attack.attack(example, ground_truth_output)
out_queue.put((i, result))
except Exception as e:
if isinstance(e, queue.Empty):
continue
out_queue.put((i, e))
if not attack_args.ignore_exceptions:
exit(1)
else:
out_queue.put((i, (e, traceback.format_exc())))
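The shutdown change is the standard sentinel pattern for multiprocessing queues: the main process enqueues one `("END", "END", "END")` triple per worker, and each worker breaks out of its loop when it receives one. A self-contained sketch of the pattern (simplified, without the attack objects):

import multiprocessing as mp
import queue

SENTINEL = ("END", "END", "END")

def worker(in_queue, out_queue):
    while True:
        try:
            item = in_queue.get(timeout=5)
        except queue.Empty:
            continue
        if item == SENTINEL:
            break  # end the process when the sentinel value is received
        i, example, label = item
        out_queue.put((i, f"attacked: {example}"))

if __name__ == "__main__":
    in_q, out_q = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=worker, args=(in_q, out_q)) for _ in range(2)]
    for w in workers:
        w.start()
    for i in range(4):
        in_q.put((i, f"example-{i}", 0))
    results = [out_q.get() for _ in range(4)]
    for _ in workers:  # one sentinel per worker, then join
        in_q.put(SENTINEL)
    for w in workers:
        w.join()
    print(sorted(results))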

View File

@@ -10,3 +10,4 @@ from .repeat_modification import RepeatModification
from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification
from .min_word_length import MinWordLength
from .max_modification_rate import MaxModificationRate

View File

@@ -92,7 +92,7 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
self.model.zero_grad()
model_device = next(self.model.parameters()).device
input_dict = self.tokenizer(
text_input,
[text_input],
add_special_tokens=True,
return_tensors="pt",
padding="max_length",
@@ -135,7 +135,7 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
"""
return [
self.tokenizer.convert_ids_to_tokens(
self.tokenizer(x, truncation=True)["input_ids"]
self.tokenizer([x], truncation=True)["input_ids"][0]
)
for x in inputs
]
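The wrapper fix reflects how HuggingFace tokenizers shape their output: a bare string is tokenized as one sequence, while a list of strings is a batch, so `input_ids` comes back flat in one case and nested in the other. Roughly:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

# A bare string is a single sequence: `input_ids` is a flat list of ids.
single = tokenizer("a short sentence", truncation=True)
print(single["input_ids"])  # e.g. [101, 1037, 2460, 6251, 102]

# A list of strings is a batch: `input_ids` is a list of lists, which is
# why the wrapper now passes [x] and indexes [0] to recover the flat list.
batched = tokenizer(["a short sentence"], truncation=True)
print(batched["input_ids"][0])  # same flat list as above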

View File

@@ -72,6 +72,9 @@ class GreedyWordSwapWIR(SearchMethod):
continue
swap_results, _ = self.get_goal_results(transformed_text_candidates)
score_change = [result.score for result in swap_results]
if not score_change:
delta_ps.append(0.0)
continue
max_score_change = np.max(score_change)
delta_ps.append(max_score_change)
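The new guard matters because `np.max` raises on an empty sequence, which happens whenever every candidate swap for a word is filtered out by the constraints; appending a neutral importance of 0.0 keeps the word-importance ranking intact:

import numpy as np

score_change = []  # no surviving transformation candidates for this word
try:
    np.max(score_change)
except ValueError as e:
    print(e)  # zero-size array to reduction operation maximum ...

delta = np.max(score_change) if score_change else 0.0  # the patched behavior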

View File

@@ -103,11 +103,13 @@ class AttackedText:
"""
if "previous_attacked_text" in self.attack_attrs:
self.attack_attrs["previous_attacked_text"].free_memory()
if "last_transformation" in self.attack_attrs:
del self.attack_attrs["last_transformation"]
self.attack_attrs.pop("previous_attacked_text", None)
self.attack_attrs.pop("last_transformation", None)
for key in list(self.attack_attrs.keys()):
if isinstance(self.attack_attrs[key], torch.Tensor):
del self.attack_attrs[key]
self.attack_attrs.pop(key, None)
def text_window_around_index(self, index, window_size):
"""The text window of ``window_size`` words centered around

View File

@@ -1,42 +1,28 @@
import collections
import functools
import json
import logging
import math
import os
import numpy as np
import scipy
import torch
import tqdm
import transformers
import textattack
from textattack.models.helpers import LSTMForClassification, WordCNNForClassification
from textattack.models.wrappers import ModelWrapper
from .attack import Attack
from .attack_args import AttackArgs
from .attack_results import MaximizedAttackResult, SuccessfulAttackResult
from .attacker import Attacker
from .model_args import HUGGINGFACE_MODELS
from .models.helpers import LSTMForClassification, WordCNNForClassification
from .models.wrappers import ModelWrapper
from .training_args import CommandLineTrainingArgs, TrainingArgs
logger = textattack.shared.logger
# Helper functions for collating data
def collate_fn(input_columns, data):
input_texts = []
labels = []
for _input, label in data:
_input = tuple(_input[c] for c in input_columns)
if len(_input) == 1:
_input = _input[0]
input_texts.append(_input)
labels.append(label)
return input_texts, torch.tensor(labels)
class Trainer:
"""Trainer is training and eval loop for adversarial training.
@@ -47,10 +33,10 @@ class Trainer:
self,
model_wrapper,
task_type,
attack,
train_dataset,
eval_dataset,
training_args,
attack=None,
train_dataset=None,
eval_dataset=None,
training_args=TrainingArgs(),
):
assert isinstance(
model_wrapper, ModelWrapper
@@ -60,15 +46,18 @@ class Trainer:
"classification",
"regression",
}, '`task_type` must either be "classification" or "regression"'
assert isinstance(
attack, Attack
), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
assert isinstance(
train_dataset, textattack.datasets.Dataset
), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."
assert isinstance(
eval_dataset, textattack.datasets.Dataset
), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."
if attack:
assert isinstance(
attack, Attack
), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
if train_dataset:
assert isinstance(
train_dataset, textattack.datasets.Dataset
), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."
if eval_dataset:
assert isinstance(
eval_dataset, textattack.datasets.Dataset
), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."
assert isinstance(
training_args, TrainingArgs
), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
@@ -94,121 +83,88 @@ class Trainer:
self.eval_dataset = eval_dataset
self.training_args = training_args
def _generate_adversarial_examples(self, dataset, epoch, eval_mode=False):
"""Generate adversarial examples using attacker."""
if eval_mode:
logger.info("Attacking model to evaluate adversarial robustness...")
self._metric_name = (
"pearson_correlation" if self.task_type == "regression" else "accuracy"
)
if self.task_type == "regression":
self.loss_fct = torch.nn.MSELoss(reduction="none")
else:
logger.info("Attacking model to generate new adversarial training set...")
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
num_examples = (
self.training_args.num_eval_adv_examples
if eval_mode
else self.training_args.num_train_adv_examples
)
query_budget = (
self.training_args.query_budget_eval
if eval_mode
else self.training_args.query_budget_train
)
shuffle = False if eval_mode else True
base_file_name = (
f"attack-eval-{epoch}" if eval_mode else f"attack-train-{epoch}"
)
self._global_step = 0
def _generate_adversarial_examples(self, epoch):
"""Generate adversarial examples using attacker."""
base_file_name = f"attack-train-{epoch}"
log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
logger.info("Attacking model to generate new adversarial training set...")
if isinstance(self.training_args.num_train_adv_examples, float):
num_train_adv_examples = math.ceil(
len(self.train_dataset) * self.training_args.num_train_adv_examples
)
else:
num_train_adv_examples = self.training_args.num_train_adv_examples
attack_args = AttackArgs(
num_examples=num_examples,
num_successful_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=query_budget,
shuffle=shuffle,
attack_n=True,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
ignore_exceptions=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
attacker = Attacker(self.attack, dataset, attack_args)
attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
results = attacker.attack_dataset()
if eval_mode:
return results
else:
attack_types = collections.Counter(r.__class__.__name__ for r in results)
total_attacks = (
attack_types["SuccessfulAttackResult"]
+ attack_types["FailedAttackResult"]
)
success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
logger.info(
f"Attack Success Rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
)
adversarial_examples = [
(
tuple(r.perturbed_result.attacked_text._text_input.values()),
r.perturbed_result.ground_truth_output,
)
for r in results
]
adversarial_dataset = textattack.datasets.Dataset(
adversarial_examples,
input_columns=dataset.input_columns,
label_map=dataset.label_map,
label_names=dataset.label_names,
output_scale_factor=dataset.output_scale_factor,
shuffle=False,
)
return adversarial_dataset
def _training_setup(self):
"""Handle all the training set ups including logging."""
textattack.shared.utils.set_seed(self.training_args.random_seed)
if not os.path.exists(self.training_args.output_dir):
os.makedirs(self.training_args.output_dir)
# Save logger writes to file
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
fh = logging.FileHandler(log_txt_path)
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.info(f"Writing logs to {log_txt_path}.")
# Save original self.training_args to file
args_save_path = os.path.join(
self.training_args.output_dir, "training_args.json"
attack_types = collections.Counter(r.__class__.__name__ for r in results)
total_attacks = (
attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
)
with open(args_save_path, "w", encoding="utf-8") as f:
json.dump(self.training_args.__dict__, f)
logger.info(f"Wrote original training args to {args_save_path}.")
def _print_training_args(self, total_training_steps):
logger.info("==================== Running Training ====================")
logger.info(f"Num epochs = {self.training_args.num_epochs}")
logger.info(f"Num clean epochs = {self.training_args.num_clean_epochs}")
logger.info(f"Num total steps = {total_training_steps}")
logger.info(f"Num training examples = {len(self.train_dataset)}")
logger.info(f"Num evaluation examples = {len(self.eval_dataset)}")
logger.info(f"Starting learning rate = {self.training_args.learning_rate}")
logger.info(f"Num warmup steps = {self.training_args.num_warmup_steps}")
logger.info(f"Weight decay = {self.training_args.weight_decay}")
def _get_tensorboard_writer(self):
from torch.utils.tensorboard import SummaryWriter
tb_writer = SummaryWriter(self.training_args.tb_log_dir)
tb_writer.add_hparams(self.training_args.__dict__, {})
tb_writer.flush()
return tb_writer
def _init_wandb(self):
global wandb
import wandb
wandb.init(
project=self.training_args.wand_project, config=self.training_args.__dict__
success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
logger.info(f"Total number of attack results: {len(results)}")
logger.info(
f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
)
adversarial_examples = [
(
tuple(r.perturbed_result.attacked_text._text_input.values()),
r.perturbed_result.ground_truth_output,
"adversarial_example",
)
for r in results
if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
]
adversarial_dataset = textattack.datasets.Dataset(
adversarial_examples,
input_columns=self.train_dataset.input_columns,
label_map=self.train_dataset.label_map,
label_names=self.train_dataset.label_names,
output_scale_factor=self.train_dataset.output_scale_factor,
shuffle=False,
)
return adversarial_dataset
def _print_training_args(self, total_training_steps, train_batch_size):
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(self.train_dataset)}")
logger.info(f" Num epochs = {self.training_args.num_epochs}")
logger.info(f" Num clean epochs = {self.training_args.num_clean_epochs}")
logger.info(
f" Instantaneous batch size per device = {self.training_args.per_device_train_batch_size}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * self.training_args.gradient_accumulation_steps}"
)
logger.info(
f" Gradient accumulation steps = {self.training_args.gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {total_training_steps}")
def _save_model_checkpoint(
self, model, tokenizer, step=None, epoch=None, best=False, last=False
@@ -242,7 +198,36 @@ class Trainer:
os.path.join(output_dir, "pytorch_model.bin"),
)
def _get_optimizer_and_scheduler(self, model, total_training_steps):
def _tb_log(self, log, step):
if not hasattr(self, "_tb_writer"):
from torch.utils.tensorboard import SummaryWriter
self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
self._tb_writer.add_hparams(self.training_args.__dict__, {})
self._tb_writer.flush()
for key in log:
self._tb_writer.add_scalar(key, log[key], step)
self._tb_writer.flush()
def _wandb_log(self, log, step):
if not hasattr(self, "_wandb_init"):
global wandb
import wandb
self._wandb_init = True
wandb.init(
project=self.training_args.wandb_project,
config=self.training_args.__dict__,
)
wandb.log(log, step=step)
def get_optimizer_and_scheduler(self, model, total_training_steps):
if isinstance(model, torch.nn.DataParallel):
model = model.module
if isinstance(model, transformers.PreTrainedModel):
# Reference https://huggingface.co/transformers/training.html
param_optimizer = list(model.named_parameters())
@@ -267,10 +252,16 @@ class Trainer:
optimizer = transformers.optimization.AdamW(
optimizer_grouped_parameters, lr=self.training_args.learning_rate
)
if isinstance(self.training_args.num_warmup_steps, float):
num_warmup_steps = math.ceil(
self.training_args.num_warmup_steps * total_training_steps
)
else:
num_warmup_steps = self.training_args.num_warmup_steps
scheduler = transformers.optimization.get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.training_args.num_warmup_steps,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_training_steps,
)
else:
@@ -282,8 +273,162 @@ class Trainer:
return optimizer, scheduler
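Since `num_warmup_steps` may now also be a float fraction of the total steps, the conversion is a single ceil; for example:

import math

total_training_steps = 1250
num_warmup_steps = 0.06  # warm up over 6% of training (illustrative value)

if isinstance(num_warmup_steps, float):
    num_warmup_steps = math.ceil(num_warmup_steps * total_training_steps)

print(num_warmup_steps)  # 75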
def get_train_dataloader(self, dataset, train_batch_size):
# Helper functions for collating data
def collate_fn(data):
input_texts = []
targets = []
is_adv_sample = []
for item in data:
if len(item) == 3:
# `len(item)` is 3 for adversarial training dataset
_input, label, adv = item
if adv != "adversarial_example":
raise ValueError(
"`item` has length of 3 but last element is not for marking if the item is an `adversarial example`."
)
else:
is_adv_sample.append(True)
else:
# else `len(item)` is 2.
_input, label = item
is_adv_sample.append(False)
_input = tuple(_input[c] for c in dataset.input_columns)
if len(_input) == 1:
_input = _input[0]
input_texts.append(_input)
targets.append(label)
return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)
train_dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_fn,
pin_memory=True,
)
return train_dataloader
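The train collate function distinguishes clean 2-tuples from adversarial 3-tuples tagged with the `"adversarial_example"` marker, so a mixed batch collates into texts, targets, and a boolean mask. A standalone sketch of the same logic (the column name is illustrative):

import torch

input_columns = ["text"]
batch = [
    ({"text": "a clean example"}, 1),
    ({"text": "a perturbed example"}, 1, "adversarial_example"),
]

input_texts, targets, is_adv_sample = [], [], []
for item in batch:
    if len(item) == 3:
        _input, label, _marker = item
        is_adv_sample.append(True)
    else:
        _input, label = item
        is_adv_sample.append(False)
    _input = tuple(_input[c] for c in input_columns)
    input_texts.append(_input[0] if len(_input) == 1 else _input)
    targets.append(label)

print(input_texts)                  # ['a clean example', 'a perturbed example']
print(torch.tensor(is_adv_sample))  # tensor([False,  True])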
def get_eval_dataloader(self, dataset, eval_batch_size):
# Helper functions for collating data
def collate_fn(data):
input_texts = []
targets = []
for _input, label in data:
_input = tuple(_input[c] for c in dataset.input_columns)
if len(_input) == 1:
_input = _input[0]
input_texts.append(_input)
targets.append(label)
return input_texts, torch.tensor(targets)
eval_dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=eval_batch_size,
shuffle=True,
collate_fn=collate_fn,
pin_memory=True,
)
return eval_dataloader
def training_step(self, model, tokenizer, batch):
"""
Args:
model (:obj:`torch.nn.Module`):
Model to train.
tokenizer:
Tokenizer used to tokenize input text.
batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
Tuple of input texts, targets, and boolean tensor indicating if the sample is an adversarial example.
"""
input_texts, targets, is_adv_sample = batch
_targets = targets
targets = targets.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel) or (
isinstance(model, torch.nn.DataParallel)
and isinstance(model.module, transformers.PreTrainedModel)
):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
input_ids.to(textattack.shared.utils.device)
logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
logits = model(input_ids)
if self.task_type == "regression":
loss = self.loss_fct(logits.squeeze(), targets.squeeze())
preds = logits
else:
loss = self.loss_fct(logits, targets)
preds = logits.argmax(dim=-1)
sample_weights = torch.ones(
is_adv_sample.size(), device=textattack.shared.utils.device
)
sample_weights[is_adv_sample] *= self.training_args.alpha
loss = loss * sample_weights
loss = torch.mean(loss)
preds = preds.cpu()
return loss, preds, _targets
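The per-sample weighting at the end of `training_step` is the usual way to mix clean and adversarial losses: with `reduction="none"` the loss stays per-sample, adversarial samples are scaled by `alpha`, and only then is the batch averaged. In isolation:

import torch

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
logits = torch.tensor([[2.0, 0.5], [0.2, 1.5]])
targets = torch.tensor([0, 1])
is_adv_sample = torch.tensor([False, True])
alpha = 0.5  # illustrative stand-in for TrainingArgs.alpha

loss = loss_fct(logits, targets)        # one loss per sample
sample_weights = torch.ones(is_adv_sample.size())
sample_weights[is_adv_sample] *= alpha  # re-weight adversarial samples
loss = (loss * sample_weights).mean()
print(loss)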
def evaluate_step(self, model, tokenizer, batch):
input_texts, targets = batch
_targets = targets
targets = targets.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
input_ids.to(textattack.shared.utils.device)
logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
logits = model(input_ids)
if self.task_type == "regression":
preds = logits
else:
preds = logits.argmax(dim=-1)
return preds.cpu(), _targets
def train(self):
self._training_setup()
textattack.shared.utils.set_seed(self.training_args.random_seed)
if not os.path.exists(self.training_args.output_dir):
os.makedirs(self.training_args.output_dir)
# Save logger writes to file
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
fh = logging.FileHandler(log_txt_path)
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.info(f"Writing logs to {log_txt_path}.")
# Save original self.training_args to file
args_save_path = os.path.join(
self.training_args.output_dir, "training_args.json"
)
with open(args_save_path, "w", encoding="utf-8") as f:
json.dump(self.training_args.__dict__, f)
logger.info(f"Wrote original training args to {args_save_path}.")
num_gpus = torch.cuda.device_count()
tokenizer = self.model_wrapper.tokenizer
model = self.model_wrapper.model
@@ -297,41 +442,39 @@ class Trainer:
else:
train_batch_size = self.training_args.per_device_train_batch_size
model.to(textattack.shared.utils.device)
total_training_steps = (
total_clean_training_steps = (
math.ceil(
len(self.train_dataset)
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
)
* self.training_args.num_epochs
* self.training_args.num_clean_epochs
)
total_adv_training_steps = math.ceil(
(len(self.train_dataset) + self.training_args.num_train_adv_examples)
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
) * (self.training_args.num_epochs - self.training_args.num_clean_epochs)
if self.training_args.log_to_tb:
tb_writer = self._get_tensorboard_writer()
if self.training_args.log_to_wandb:
self._init_wandb()
total_training_steps = total_clean_training_steps + total_adv_training_steps
self._print_training_args(total_training_steps, train_batch_size)
optimizer, scheduler = self._get_optimizer_and_scheduler(
optimizer, scheduler = self.get_optimizer_and_scheduler(
model, total_training_steps
)
collate_func = functools.partial(collate_fn, self.train_dataset.input_columns)
if self.task_type == "regression":
loss_fct = torch.nn.MSELoss()
else:
loss_fct = torch.nn.CrossEntropyLoss()
model.to(textattack.shared.utils.device)
# Variables across epochs
global_step = 0
total_loss = 0.0
self._total_loss = 0.0
self._current_loss = 0.0
self._last_log_step = 0
# `best_score` is used to keep track of the best model across training.
# Could be loss, accuracy, or other metrics.
best_eval_score = 0.0
best_eval_score_epoch = 0
best_model_path = None
epochs_since_best_eval_score = 0
self._print_training_args(total_training_steps)
for epoch in range(1, self.training_args.num_epochs + 1):
logger.info("==========================================================")
logger.info(f"Epoch {epoch}")
@@ -344,14 +487,10 @@ class Trainer:
# after the clean epochs
# `adv_example_dataset` is an instance of `textattack.datasets.Dataset`
model.eval()
model.cpu()
adv_example_dataset = self._generate_adversarial_examples(
self.train_dataset, epoch
)
adv_example_dataset = self._generate_adversarial_examples(epoch)
train_dataset = torch.utils.data.ConcatDataset(
[self.train_dataset, adv_example_dataset]
)
model.to(textattack.shared.utils.device)
model.train()
else:
train_dataset = self.train_dataset
@@ -361,116 +500,123 @@ class Trainer:
)
train_dataset = self.train_dataset
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_func,
train_dataloader = self.get_train_dataloader(
train_dataset, train_batch_size
)
model.train()
# Epoch-specific variables
correct_predictions = 0
total_predictions = 0
# Epoch variables
all_preds = []
all_targets = []
prog_bar = tqdm.tqdm(
train_dataloader, desc="Iteration", position=0, leave=True
)
for step, batch in enumerate(prog_bar):
input_texts, labels = batch
labels = labels.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
for key in input_ids:
if isinstance(input_ids[key], torch.Tensor):
input_ids[key] = input_ids[key].to(
textattack.shared.utils.device
)
logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
logits = model(input_ids)
if self.task_type == "regression":
# TODO integrate with textattack `metrics` package
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
pred_labels = logits.argmax(dim=-1)
correct_predictions += (pred_labels == labels).sum().item()
total_predictions += len(pred_labels)
loss, preds, targets = self.training_step(model, tokenizer, batch)
if isinstance(model, torch.nn.DataParallel):
loss = loss.mean()
if self.training_args.gradient_accumulation_steps > 1:
loss = loss / self.training_args.gradient_accumulation_steps
loss.backward()
total_loss += loss.item()
loss = loss / self.training_args.gradient_accumulation_steps
loss.backward()
loss = loss.item()
self._total_loss += loss
self._current_loss += loss
all_preds.append(preds)
all_targets.append(targets)
if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
optimizer.step()
if scheduler:
scheduler.step()
optimizer.zero_grad()
global_step += 1
self._global_step += 1
if self._global_step > 0:
prog_bar.set_description(
f"Loss {self._total_loss/self._global_step:.5f}"
)
# TODO: Better way to handle TB and Wandb logging
if global_step % self.training_args.logging_interval_step == 0:
if (self._global_step > 0) and (
self._global_step % self.training_args.logging_interval_step == 0
):
lr_to_log = (
scheduler.get_last_lr()[0]
if scheduler
else self.training_args.learning_rate
)
if self._global_step - self._last_log_step >= 1:
loss_to_log = round(
self._current_loss
/ (self._global_step - self._last_log_step),
4,
)
else:
loss_to_log = round(self._current_loss, 4)
log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
if self.training_args.log_to_tb:
tb_writer.add_scalar("loss", loss.item(), global_step)
tb_writer.add_scalar("lr", lr_to_log, global_step)
self._tb_log(log, self._global_step)
if self.training_args.log_to_wandb:
wandb.log({"loss": loss.item()}, step=global_step)
wandb.log({"lr": lr_to_log}, step=global_step)
self._wandb_log(log, self._global_step)
if global_step > 0:
prog_bar.set_description(f"Loss {total_loss/global_step:.5f}")
self._current_loss = 0.0
self._last_log_step = self._global_step
# Save model checkpoint to file.
if self.training_args.checkpoint_interval_steps:
if (
global_step > 0
and self.training_args.checkpoint_interval_steps > 0
and (global_step % self.training_args.checkpoint_interval_steps)
self._global_step > 0
and (
self._global_step
% self.training_args.checkpoint_interval_steps
)
== 0
):
self._save_model_checkpoint(model, tokenizer, step=global_step)
self._save_model_checkpoint(
model, tokenizer, step=self._global_step
)
# Print training accuracy, if we're tracking it.
if total_predictions > 0:
train_acc = correct_predictions / total_predictions
logger.info(f"Train accuracy: {train_acc*100:.2f}%")
preds = torch.cat(all_preds)
targets = torch.cat(all_targets)
if self._metric_name == "accuracy":
correct_predictions = (preds == targets).sum().item()
accuracy = correct_predictions / len(targets)
metric_log = {"train/train_accuracy": accuracy}
logger.info(f"Train accuracy: {accuracy*100:.2f}%")
else:
pearson_correlation, pearson_pvalue = scipy.stats.pearsonr(
preds, targets
)
metric_log = {
"train/pearson_correlation": pearson_correlation,
"train/pearson_pvalue": pearson_pvalue,
}
logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}%")
if len(targets) > 0:
if self.training_args.log_to_tb:
tb_writer.add_scalar("epoch_train_acc", train_acc, global_step)
self._tb_log(metric_log, epoch)
if self.training_args.log_to_wandb:
wandb.log({"epoch_train_acc": train_acc}, step=global_step)
metric_log["epoch"] = epoch
self._wandb_log(metric_log, self._global_step)
# Evaluate after each epoch.
eval_score = self.evaluate()
# Check eval accuracy after each epoch.
eval_score = self._evaluate(model, tokenizer)
logger.info(
f"Eval {'pearson correlation' if self.task_type == 'regression' else 'accuracy'}: {eval_score*100:.2f}%"
)
if self.training_args.log_to_tb:
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
if self.training_args.log_to_wandb:
wandb.log({"epoch_eval_score": eval_score}, step=global_step)
self._wandb_log(
{f"eval/{self._metric_name}": eval_score, "epoch": epoch},
self._global_step,
)
if (
self.training_args.checkpoint_interval_epochs
and ((epoch - 1) % self.training_args.checkpoint_interval_epochs) == 0
and (epoch % self.training_args.checkpoint_interval_epochs) == 0
):
self._save_model_checkpoint(model, tokenizer, epoch=epoch)
@@ -480,7 +626,7 @@ class Trainer:
epochs_since_best_eval_score = 0
self._save_model_checkpoint(model, tokenizer, best=True)
logger.info(
f"Best acc found. Saved model to {self.training_args.output_dir}/best_model/"
f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
)
else:
epochs_since_best_eval_score += 1
@@ -493,64 +639,7 @@ class Trainer:
)
break
if self.training_args.eval_adversarial_robustness and (
epoch >= self.training_args.num_clean_epochs
):
# Evaluate adversarial robustness
model.eval()
model.cpu()
adv_attack_results = self._generate_adversarial_examples(
self.eval_dataset, epoch, eval_mode=True
)
model.to(textattack.shared.utils.device)
model.train()
attack_types = [r.__class__.__name__ for r in adv_attack_results]
attack_types = collections.Counter(attack_types)
total_attacks = (
attack_types["SuccessfulAttackResult"]
+ attack_types["FailedAttackResult"]
)
adv_succ_rate = attack_types["SuccessfulAttackResult"] / total_attacks
num_queries = np.array(
[
r.num_queries
for r in adv_attack_results
if not isinstance(
r, textattack.attack_results.SkippedAttackResult
)
]
)
avg_num_queries = round(num_queries.mean(), 2)
if self.training_args.log_to_tb:
tb_writer.add_scalar(
"robustness_total_attacks", total_attacks, global_step
)
tb_writer.add_scalar(
"robustness_attack_succ_rate", adv_succ_rate, global_step
)
tb_writer.add_scalar(
"robustness_avg_num_queries", avg_num_queries, global_step
)
if self.training_args.log_to_wandb:
wandb.log(
{"robustness_total_attacks": total_attacks}, step=global_step
)
wandb.log(
{"robustness_attack_succ_rate": adv_succ_rate}, step=global_step
)
wandb.log(
{"robustness_avg_num_queries": avg_num_queries},
step=global_step,
)
logger.info(f"Eval total attack: {total_attacks}")
logger.info(f"Eval attack success rate: {100*adv_succ_rate:.2f}%")
logger.info(f"Eval avg num queries: {avg_num_queries}")
if self.training_args.log_to_tb:
tb_writer.flush()
# Finish training
if isinstance(model, torch.nn.DataParallel):
model = model.module
@@ -560,20 +649,23 @@ class Trainer:
model = model.__class__.from_pretrained(best_model_path)
else:
model.load_state_dict(
torch.load(os.path.join(best_model_path, "model.pt"))
torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
)
if self.training_args.save_last:
self._save_model_checkpoint(model, tokenizer, last=True)
self.model_wrapper.model = model
self.write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
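Taken together, a hedged end-to-end sketch of how the refactored `Trainer` is meant to be driven (the constructor argument order is an assumption based on the attributes used above, and the model/dataset names are illustrative):

import transformers
import textattack

# Wrap a HuggingFace model so TextAttack can query it uniformly.
model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

train_dataset = textattack.datasets.HuggingFaceDataset("rotten_tomatoes", split="train")
eval_dataset = textattack.datasets.HuggingFaceDataset("rotten_tomatoes", split="test")

training_args = textattack.TrainingArgs(
    num_epochs=3,
    num_clean_epochs=1,
    load_best_model_at_end=True,
)
trainer = textattack.Trainer(
    model_wrapper,
    "classification",  # task_type: "classification" or "regression"
    None,              # no Attack object: plain, non-adversarial training
    train_dataset,
    eval_dataset,
    training_args,
)
trainer.train()
trainer.evaluate()  # public wrapper around _evaluate()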
def evaluate(self):
logging.info("Evaluating model on evaluation dataset.")
model = self.model_wrapper.model
tokenizer = self.model_wrapper.tokenizer
def _evaluate(self, model, tokenizer):
model.eval()
correct = 0
logits = []
labels = []
all_preds = []
all_targets = []
if isinstance(model, torch.nn.DataParallel):
num_gpus = torch.cuda.device_count()
@@ -581,58 +673,33 @@ class Trainer:
else:
eval_batch_size = self.training_args.per_device_eval_batch_size
collate_func = functools.partial(collate_fn, self.eval_dataset.input_columns)
eval_dataloader = torch.utils.data.DataLoader(
self.eval_dataset,
batch_size=eval_batch_size,
collate_fn=collate_func,
)
eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)
with torch.no_grad():
for input_texts, batch_labels in eval_dataloader:
batch_labels = batch_labels.to(textattack.shared.utils.device)
if isinstance(model, transformers.PreTrainedModel):
input_ids = tokenizer(
input_texts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
for key in input_ids:
if isinstance(input_ids[key], torch.Tensor):
input_ids[key] = input_ids[key].to(
textattack.shared.utils.device
)
batch_logits = model(**input_ids)[0]
else:
input_ids = tokenizer(input_texts)
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(textattack.shared.utils.device)
batch_logits = model(input_ids)
for step, batch in enumerate(eval_dataloader):
preds, targets = self.evaluate_step(model, tokenizer, batch)
all_preds.append(preds)
all_targets.append(targets)
logits.extend(batch_logits.cpu().squeeze().tolist())
labels.extend(batch_labels)
model.train()
logits = torch.tensor(logits)
labels = torch.tensor(labels)
preds = torch.cat(all_preds)
targets = torch.cat(all_targets)
if self.task_type == "regression":
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
return pearson_correlation
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(preds, targets)
eval_score = pearson_correlation
else:
preds = logits.argmax(dim=1)
correct = (preds == labels).sum()
return float(correct) / len(labels)
correct_predictions = (preds == targets).sum().item()
accuracy = correct_predictions / len(targets)
eval_score = accuracy
if self._metric_name == "accuracy":
logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
else:
logger.info(f"Eval {self._metric_name}: {eval_score:.4f}")
return eval_score
def evaluate(self):
logging.info("Evaluating model on evaluation dataset.")
model = self.model_wrapper.model
tokenizer = self.model_wrapper.tokenizer
return self._evaluate(model, tokenizer)
def write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
if isinstance(self.training_args, CommandLineTrainingArgs):
model_name = self.training_args.model_name_or_path
elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
@@ -658,7 +725,7 @@ class Trainer:
):
model_max_length = self.training_args.model_max_length
elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
model_max_length = self.model_wrapper.config.max_position_embeddings
model_max_length = self.model_wrapper.model.config.max_position_embeddings
elif isinstance(
self.model_wrapper.model, (LSTMForClassification, WordCNNForClassification)
):


@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
import datetime
import os
from typing import Union
from textattack.datasets import HuggingFaceDataset
from textattack.models.helpers import LSTMForClassification, WordCNNForClassification
@@ -15,8 +16,6 @@ from textattack.shared.utils import ARGS_SPLIT_TOKEN
from .attack import Attack
from .attack_args import ATTACK_RECIPE_NAMES
# TODO Add `metric_for_best_model` argument. Currently we just use accuracy for classification and Pearson correlation for regression by default.
def default_output_dir():
return os.path.join(
@@ -26,47 +25,72 @@ def default_output_dir():
@dataclass
class TrainingArgs:
"""Args for TextAttack ``Trainer`` class that is used for running
adversarial training.
"""Arguments for ``Trainer`` class that is used for adversarial training.
Args:
num_epochs (int): Total number of epochs for training. Default is 5.
num_clean_epochs (int): Number of epochs to train on just the original training dataset before adversarial training. Default is 0.
attack_epoch_interval (int): Generate a new adversarial training set every N epochs. Default is 1.
early_stopping_epochs (int): Number of epochs validation must increase before stopping early (-1 for no early stopping). Default is `None`.
learning_rate (float): Learning rate for Adam Optimization. Default is 2e-5.
num_warmup_steps (int): The number of steps for the warmup phase of linear scheduler. Default is 500.
weight_decay (float): Weight decay (L2 penalty). Default is 0.01.
per_device_train_batch_size (int): The batch size per GPU/CPU for training. Default is 8.
per_device_eval_batch_size (int): The batch size per GPU/CPU for evaluation. Default is 32.
gradient_accumulation_steps (int): Number of update steps to accumulate the gradients for before performing a backward/update pass. Default is 1.
random_seed (int): Random seed. Default is 786.
parallel (bool): If `True`, train using multiple GPUs. Default is `False`.
load_best_model_at_end (bool): If `True`, keep track of the best model across training and load it at the end.
eval_adversarial_robustness (bool): If set, evaluate adversarial robustness on evaluation dataset after every epoch.
num_eval_adv_examples (int): The number of samples to attack if `eval_adversarial_robustness=True`. Default is 1000.
num_train_adv_examples (int): The number of samples to attack when generating adversarial training set. Default is -1 (which is all possible samples).
query_budget_train (:obj:`int`, optional): The max query budget to use when generating adversarial training set.
query_budget_eval (:obj:`int`, optional): The max query budget to use when evaluating adversarial robustness.
attack_num_workers_per_device (int): Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.
output_dir (str): Directory to output training logs and checkpoints.
checkpoint_interval_steps (int): Save model checkpoint after every N updates to the model.
checkpoint_interval_epochs (int): Save model checkpoint after every N epochs.
save_last (bool): If `True`, save the model at end of training. Can be used with `load_best_model_at_end` to save the best model at the end. Default is `True`.
log_to_tb (bool): If `True`, log to Tensorboard. Default is `False`
tb_log_dir (str): Path of Tensorboard log directory.
log_to_wandb (bool): If `True`, log to Wandb. Default is `False`.
wandb_project (str): Name of Wandb project for logging. Default is `textattack`.
logging_interval_step (int): Log to Tensorboard/Wandb every N steps.
num_epochs (:obj:`int`, `optional`, defaults to :obj:`3`):
Total number of epochs for training.
num_clean_epochs (:obj:`int`, `optional`, defaults to :obj:`1`):
Number of epochs to train on just the original training dataset before adversarial training.
attack_epoch_interval (:obj:`int`, `optional`, defaults to :obj:`1`):
Generate a new adversarial training set every N epochs.
early_stopping_epochs (:obj:`int`, `optional`, defaults to :obj:`None`):
Number of epochs to wait without improvement in the validation score before stopping early (:obj:`None` for no early stopping).
learning_rate (:obj:`float`, `optional`, defaults to :obj:`5e-5`):
Learning rate for the optimizer.
num_warmup_steps (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`500`):
The number of steps for the warmup phase of the linear scheduler.
If `num_warmup_steps` is a `float` between 0 and 1, the number of warmup steps will be `math.ceil(num_training_steps * num_warmup_steps)`.
weight_decay (:obj:`float`, `optional`, defaults to :obj:`0.01`):
Weight decay (L2 penalty).
per_device_train_batch_size (:obj:`int`, `optional`, defaults to :obj:`8`):
The batch size per GPU/CPU for training.
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
The batch size per GPU/CPU for evaluation.
gradient_accumulation_steps (:obj:`int`, `optional`, defaults to :obj:`1`):
Number of update steps to accumulate the gradients for before performing a backward/update pass.
random_seed (:obj:`int`, `optional`, defaults to :obj:`786`):
Random seed for reproducibility.
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, train with multiple GPUs via `torch.nn.DataParallel`.
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, keep track of the best model across training and load it at the end.
alpha (:obj:`float`, `optional`, defaults to :obj:`1.0`):
The weight for the adversarial loss.
num_train_adv_examples (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`-1`):
The number of samples to successfully attack when generating the adversarial training set before the start of every epoch.
If `num_train_adv_examples` is a `float` between 0 and 1, the number of adversarial examples generated is
that fraction of the original training set.
query_budget_train (:obj:`int`, `optional`, defaults to :obj:`None`):
The max query budget to use when generating the adversarial training set. `None` means an infinite query budget.
attack_num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.
output_dir (:obj:`str`, `optional`):
Directory to output training logs and checkpoints. Defaults to `./outputs/%Y-%m-%d-%H-%M-%S-%f` format.
checkpoint_interval_steps (:obj:`int`, `optional`, defaults to :obj:`None`):
If set, save model checkpoint after every N updates to the model.
checkpoint_interval_epochs (:obj:`int`, `optional`, defaults to :obj:`None`):
If set, save model checkpoint after every N epochs.
save_last (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `True`, save the model at end of training. Can be used with `load_best_model_at_end` to save the best model at the end.
log_to_tb (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, log to Tensorboard.
tb_log_dir (:obj:`str`, `optional`, defaults to :obj:`"./runs"`):
Path of Tensorboard log directory.
log_to_wandb (:obj:`bool`, `optional`, defaults to :obj:`False`):
If `True`, log to Wandb.
wandb_project (:obj:`str`, `optional`, defaults to :obj:`"textattack"`):
Name of Wandb project for logging.
logging_interval_step (:obj:`int`, `optional`, defaults to :obj:`1`):
Log to Tensorboard/Wandb every N training steps.
"""
num_epochs: int = 5
num_clean_epochs: int = 0
num_epochs: int = 3
num_clean_epochs: int = 1
attack_epoch_interval: int = 1
early_stopping_epochs: int = None
learning_rate: float = 5e-5
lr: float = None # alternative keyword arg for learning_rate
num_warmup_steps: int = 500
num_warmup_steps: Union[int, float] = 500
weight_decay: float = 0.01
per_device_train_batch_size: int = 8
per_device_eval_batch_size: int = 32
@@ -74,11 +98,9 @@ class TrainingArgs:
random_seed: int = 786
parallel: bool = False
load_best_model_at_end: bool = False
eval_adversarial_robustness: bool = False
num_eval_adv_examples: int = 1000
num_train_adv_examples: int = -1
alpha: float = 1.0
num_train_adv_examples: Union[int, float] = -1
query_budget_train: int = None
query_budget_eval: int = None
attack_num_workers_per_device: int = 1
output_dir: str = field(default_factory=default_output_dir)
checkpoint_interval_steps: int = None
@@ -91,8 +113,6 @@ class TrainingArgs:
logging_interval_step: int = 1
def __post_init__(self):
if self.lr:
self.learning_rate = self.lr
assert self.num_epochs > 0, "`num_epochs` must be greater than 0."
assert (
self.num_clean_epochs >= 0
@@ -101,16 +121,13 @@ class TrainingArgs:
assert (
self.early_stopping_epochs > 0
), "`early_stopping_epochs` must be greater than 0."
if self.attack_epoch_interval is not None:
assert (
self.attack_epoch_interval > 0
), "`attack_epoch_interval` must be greater than 0."
assert (
self.attack_epoch_interval > 0
), "`attack_epoch_interval` must be greater than 0."
assert self.num_warmup_steps > 0, "`num_warmup_steps` must be greater than 0."
assert (
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
), "`num_train_adv_examples` must be greater than 0 or equal to -1."
assert (
self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1
), "`num_eval_adv_examples` must be greater than 0 or equal to -1."
self.num_warmup_steps >= 0
), "`num_warmup_steps` must be greater than or equal to 0."
assert (
self.gradient_accumulation_steps > 0
), "`gradient_accumulation_steps` must be greater than 0."
@@ -118,117 +135,130 @@ class TrainingArgs:
self.num_clean_epochs <= self.num_epochs
), f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
if isinstance(self.num_train_adv_examples, float):
assert (
self.num_train_adv_examples >= 0.0
and self.num_train_adv_examples <= 1.0
), "If `num_train_adv_examples` is float, it must be between 0 and 1."
elif isinstance(self.num_train_adv_examples, int):
assert (
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
), "If `num_train_adv_examples` is int, it must be greater than 0 or equal to -1."
else:
raise TypeError("`num_train_adv_examples` must be of type `int` or `float`.")
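Given the validation above, a small sketch of argument combinations that pass `__post_init__` (values are illustrative):

from textattack import TrainingArgs

args = TrainingArgs(
    num_epochs=5,
    num_clean_epochs=1,          # must not exceed num_epochs
    num_warmup_steps=0.1,        # float in (0, 1): ceil(0.1 * num_training_steps) warmup steps
    num_train_adv_examples=0.2,  # float in [0, 1]: attack 20% of the training set
)

# An int num_train_adv_examples must be positive or exactly -1 (all samples);
# e.g. TrainingArgs(num_train_adv_examples=0) raises an AssertionError.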
@classmethod
def add_parser_args(cls, parser):
"""Add listed args to command line parser."""
default_obj = cls()
def int_or_float(v):
try:
return int(v)
except ValueError:
return float(v)
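The `int_or_float` helper lets a single flag accept either an absolute count or a fraction; its behavior in isolation:

def int_or_float(v):
    try:
        return int(v)
    except ValueError:
        return float(v)

assert int_or_float("500") == 500   # absolute number of steps/examples
assert int_or_float("0.1") == 0.1   # fraction of the total, per the docstrings above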
parser.add_argument(
"--num-epochs",
type=int,
default=4,
default=default_obj.num_epochs,
help="Total number of epochs for training.",
)
parser.add_argument(
"--num-clean-epochs",
type=int,
default=0,
default=default_obj.num_clean_epochs,
help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
)
parser.add_argument(
"--attack-epoch-interval",
type=int,
default=1,
default=default_obj.attack_epoch_interval,
help="Generate a new adversarial training set every N epochs.",
)
parser.add_argument(
"--early-stopping-epochs",
type=int,
default=None,
default=default_obj.early_stopping_epochs,
help="Number of epochs validation must increase before stopping early (-1 for no early stopping)",
)
parser.add_argument(
"--learning-rate",
"--lr",
type=float,
default=5e-5,
default=default_obj.learning_rate,
help="Learning rate for Adam Optimization.",
)
parser.add_argument(
"--num-warmup-steps",
type=float,
default=500,
type=int_or_float,
default=default_obj.num_warmup_steps,
help="The number of steps for the warmup phase of linear scheduler.",
)
parser.add_argument(
"--weight-decay",
type=float,
default=0.01,
default=default_obj.weight_decay,
help="Weight decay (L2 penalty).",
)
parser.add_argument(
"--per-device-train-batch-size",
type=int,
default=8,
default=default_obj.per_device_train_batch_size,
help="The batch size per GPU/CPU for training.",
)
parser.add_argument(
"--per-device-eval-batch-size",
type=int,
default=32,
default=default_obj.per_device_eval_batch_size,
help="The batch size per GPU/CPU for evaluation.",
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
default=1,
default=default_obj.gradient_accumulation_steps,
help="Number of updates steps to accumulate the gradients for, before performing a backward/update pass.",
)
parser.add_argument("--random-seed", type=int, default=786, help="Random seed.")
parser.add_argument(
"--random-seed",
type=int,
default=default_obj.random_seed,
help="Random seed.",
)
parser.add_argument(
"--parallel",
action="store_true",
default=False,
default=default_obj.parallel,
help="If set, run training on multiple GPUs.",
)
parser.add_argument(
"--load-best-model-at-end",
action="store_true",
default=False,
default=default_obj.load_best_model_at_end,
help="If set, keep track of the best model across training and load it at the end.",
)
parser.add_argument(
"--eval-adversarial-robustness",
action="store_true",
default=False,
help="If set, evaluate adversarial robustness on evaluation dataset after every epoch.",
)
parser.add_argument(
"--num-eval-adv-examples",
type=int,
default=1000,
help="The number of samples attack if `eval_adversarial_robustness=True`. Default is 1000.",
"--alpha",
type=float,
default=1.0,
help="The weight of adversarial loss.",
)
parser.add_argument(
"--num-train-adv-examples",
type=int,
default=-1,
type=int_or_float,
default=default_obj.num_train_adv_examples,
help="The number of samples to attack when generating adversarial training set. Default is -1 (which is all possible samples).",
)
parser.add_argument(
"--query-budget-train",
type=int,
default=None,
default=default_obj.query_budget_train,
help="The max query budget to use when generating adversarial training set.",
)
parser.add_argument(
"--query-budget-eval",
type=int,
default=None,
help="The max query budget to use when evaluating adversarial robustness.",
)
parser.add_argument(
"--attack-num-workers-per-device",
type=int,
default=1,
default=default_obj.attack_num_workers_per_device,
help="Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.",
)
parser.add_argument(
@@ -240,49 +270,49 @@ class TrainingArgs:
parser.add_argument(
"--checkpoint-interval-steps",
type=int,
default=None,
default=default_obj.checkpoint_interval_steps,
help="Save model checkpoint after every N updates to the model.",
)
parser.add_argument(
"--checkpoint-interval-epochs",
type=int,
default=None,
default=default_obj.checkpoint_interval_epochs,
help="Save model checkpoint after every N epochs.",
)
parser.add_argument(
"--save-last",
action="store_true",
default=True,
default=default_obj.save_last,
help="If set, save the model at end of training. Can be used with `--load-best-model-at-end` to save the best model at the end.",
)
parser.add_argument(
"--log-to-tb",
action="store_true",
default=False,
default=default_obj.log_to_tb,
help="If set, log to Tensorboard",
)
parser.add_argument(
"--tb-log-dir",
type=str,
default=None,
default=default_obj.tb_log_dir,
help="Path of Tensorboard log directory.",
)
parser.add_argument(
"--log-to-wandb",
action="store_true",
default=False,
default=default_obj.log_to_wandb,
help="If set, log to Wandb.",
)
parser.add_argument(
"--wandb-project",
type=str,
default="textattack",
default=default_obj.wandb_project,
help="Name of Wandb project for logging.",
)
parser.add_argument(
"--logging-interval-step",
type=int,
default=1,
default=default_obj.logging_interval_step,
help="Log to Tensorboard/Wandb every N steps.",
)
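As a quick check of how these flags parse, a hedged sketch that feeds arguments to the parser built by `add_parser_args` (flag values are illustrative):

import argparse

from textattack import TrainingArgs

parser = argparse.ArgumentParser()
TrainingArgs.add_parser_args(parser)

# --num-warmup-steps goes through int_or_float, so "0.06" parses as a fraction.
args = parser.parse_args(
    ["--num-epochs", "3", "--num-warmup-steps", "0.06", "--log-to-tb"]
)
print(args.num_epochs, args.num_warmup_steps, args.log_to_tb)  # 3 0.06 True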