mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
make trainer extendable
This commit is contained in:
@@ -28,6 +28,7 @@ Complete API Reference
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:exclude-members: CommandLineAttackArgs
|
||||
|
||||
.. automodule:: textattack.attack
|
||||
:members:
|
||||
@@ -48,18 +49,9 @@ Complete API Reference
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:exclude-members: CommandLineTrainingArgs
|
||||
|
||||
.. automodule:: textattack.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.dataset_args
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.model_args
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
bert-score>=0.3.5
|
||||
editdistance
|
||||
flair==0.6.1.post1
|
||||
flair
|
||||
filelock
|
||||
language_tool_python
|
||||
lemminflect
|
||||
|
||||
@@ -30,24 +30,22 @@ goldmember is funny enough to [91mreasoned[0m the embarrassment of bringing a
|
||||
|
||||
|
||||
--------------------------------------------- Result 3 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (60%)[0m
|
||||
[92mPositive (100%)[0m --> [91m[FAILED][0m
|
||||
|
||||
it [92mmay[0m not be particularly [92minnovative[0m , but the film's [92mcrisp[0m , [92munaffected[0m [92mstyle[0m and [92mair[0m of [92mgentle[0m [92mlonging[0m [92mmake[0m it [92munexpectedly[0m [92mrewarding[0m .
|
||||
|
||||
it [91mprobable[0m not be particularly [91mcreative[0m , but the film's [91mbrusque[0m , [91mundamaged[0m [91mshape[0m and [91mmidair[0m of [91mmild[0m [91mdesiring[0m [91mdoing[0m it [91msurprisingly[0m [91mbeneficial[0m .
|
||||
it may not be particularly innovative , but the film's crisp , unaffected style and air of gentle longing make it unexpectedly rewarding .
|
||||
|
||||
|
||||
|
||||
+-------------------------------+--------+
|
||||
| Attack Results | |
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 3 |
|
||||
| Number of failed attacks: | 0 |
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 1 |
|
||||
| Number of skipped attacks: | 0 |
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 23.71% |
|
||||
| Accuracy under attack: | 33.33% |
|
||||
| Attack success rate: | 66.67% |
|
||||
| Average perturbed word %: | 9.38% |
|
||||
| Average num. words per input: | 15.0 |
|
||||
| Avg num queries: | 71.0 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
@@ -9,6 +9,7 @@ Attack: TextAttack builds attacks from four components:
|
||||
|
||||
The ``Attack`` class represents an adversarial attack composed of a goal function, search method, transformation, and constraints.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
import lru
|
||||
import torch
|
||||
@@ -358,8 +359,11 @@ class Attack:
|
||||
AttackResult
|
||||
"""
|
||||
assert isinstance(
|
||||
example, AttackedText
|
||||
), "`example` must be of type `AttackedText`."
|
||||
example, (str, OrderedDict, AttackedText)
|
||||
), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
|
||||
if isinstance(example, (str, OrderedDict)):
|
||||
example = AttackedText(example)
|
||||
|
||||
assert isinstance(
|
||||
ground_truth_output, (int, str)
|
||||
), "`ground_truth_output` must either be `str` or `int`."
|
||||
|
||||
@@ -111,43 +111,73 @@ GOAL_FUNCTION_CLASS_NAMES = {
|
||||
|
||||
@dataclass
|
||||
class AttackArgs:
|
||||
"""Attack args for running attacks via API. This assumes that ``Attack``
|
||||
has already been created by the user.
|
||||
"""Attack arguments for running attacks via API. This assumes that
|
||||
``Attack`` has already been created by the user.
|
||||
|
||||
Args:
|
||||
num_examples (int): The number of examples to attack. -1 for entire dataset.
|
||||
num_examples_offset (int): The offset to start at in the dataset.
|
||||
query_budget (int): The maximum number of model queries allowed per example attacked.
|
||||
This is optional and setting this overwrites the query budget set in `GoalFunction` object.
|
||||
shuffle (bool): If `True`, shuffle the samples before we attack the dataset. Note this does not involve shuffling the dataset internally. Default is False.
|
||||
attack_n (bool): Whether to run attack until total of `n` examples have been attacked (not skipped).
|
||||
checkpoint_dir (str): The directory to save checkpoint files.
|
||||
checkpoint_interval (int): If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.
|
||||
random_seed (int): Random seed for reproducibility. Default is 765.
|
||||
parallel (bool): Run attack using multiple CPUs/GPUs.
|
||||
num_workers_per_device (int): Number of worker processes to run per device. For example, if you are using GPUs and ``num_workers_per_device=2``,
|
||||
then 2 processes will be running in each GPU. If you are only using CPU, then this is equivalent to running 2 processes concurrently.
|
||||
log_to_txt (str): Path to which to save attack logs as a text file. Set this argument if you want to save text logs.
|
||||
If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.
|
||||
log_to_csv (str): Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs.
|
||||
If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.
|
||||
csv_coloring_style (str): Method for choosing how to mark perturbed parts of the text. Options are "file" and "plain".
|
||||
"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".
|
||||
log_to_visdom (dict): Set this argument if you want to log attacks to Visdom. The dictionary should have the following
|
||||
three keys and their corresponding values: `"env", "port", "hostname"` (e.g. `{"env": "main", "port": 8097, "hostname": "localhost"}`).
|
||||
log_to_wandb (str): Name of the wandb project. Set this argument if you want to log attacks to Wandb.
|
||||
disable_stdout (bool): Disable logging attack results to stdout.
|
||||
silent (bool): Disable all logging.
|
||||
ignore_exceptions (bool): Skip examples that raise an error instead of exiting.
|
||||
num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
|
||||
The number of examples to attack. -1 for entire dataset.
|
||||
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
The number of successful adversarial examples we want. This is different from `num_examples`
|
||||
as `num_examples` only cares about attacking `N` samples while `num_successful_examples` aims to keep attacking
|
||||
until we have `N` successful cases.
|
||||
|
||||
.. note::
|
||||
If set, this argument overrides `num_examples` argument.
|
||||
num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
|
||||
The offset index to start at in the dataset.
|
||||
attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to run attack until total of `N` examples have been attacked (and not skipped).
|
||||
shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
|
||||
the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
|
||||
`shuffle` can now be used with checkpoint saving.
|
||||
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
The maximum number of model queries allowed per example attacked.
|
||||
If not set, we use the query budget set in the `GoalFunction` object (which by default is `float("inf")`).
|
||||
|
||||
.. note::
|
||||
Setting this overwrites the query budget set in `GoalFunction` object.
|
||||
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
If set, checkpoint will be saved after attacking every `N` examples. If `None` is passed, no checkpoints will be saved.
|
||||
checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
|
||||
The directory to save checkpoint files.
|
||||
random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
|
||||
Random seed for reproducibility.
|
||||
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, run attack using multiple CPUs/GPUs.
|
||||
num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Number of worker processes to run per device in parallel mode (i.e. `parallel=True`). For example, if you are using GPUs and `num_workers_per_device=2`,
|
||||
then 2 processes will be running in each GPU.
|
||||
log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
If set, save attack logs as a `.txt` file to the directory specified by this argument.
|
||||
If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
|
||||
log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
If set, save attack logs as a CSV file to the directory specified by this argument.
|
||||
If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
|
||||
csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
|
||||
Method for choosing how to mark perturbed parts of the text. Options are `"file"` and `"plain"`.
|
||||
`"file"` wraps perturbed parts with double brackets `[[ <text> ]]` while `"plain"` does not mark the text in any way.
|
||||
log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
|
||||
If set, Visdom logger is used with the provided dictionary passed as keyword arguments to `textattack.loggers.VisdomLogger`.
|
||||
Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
|
||||
three keys and their corresponding values: `"env"`, `"port"`, `"hostname"`.
|
||||
log_to_wandb (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
If set, log the attack results and summary to Wandb project specified by this argument.
|
||||
disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Disable displaying individual attack results to stdout.
|
||||
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Disable all logging (except for errors). This is stronger than `disable_stdout`.
|
||||
"""
|
||||
|
||||
num_examples: int = 5
|
||||
num_examples: int = 10
|
||||
num_successful_examples: int = None
|
||||
num_examples_offset: int = 0
|
||||
query_budget: int = None
|
||||
shuffle: bool = False
|
||||
attack_n: bool = False
|
||||
checkpoint_dir: str = "checkpoints"
|
||||
shuffle: bool = False
|
||||
query_budget: int = None
|
||||
checkpoint_interval: int = None
|
||||
checkpoint_dir: str = "checkpoints"
|
||||
random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
|
||||
parallel: bool = False
|
||||
num_workers_per_device: int = 1
|
||||
@@ -158,106 +188,138 @@ class AttackArgs:
|
||||
log_to_wandb: str = None
|
||||
disable_stdout: bool = False
|
||||
silent: bool = False
|
||||
ignore_exceptions: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_successful_examples:
|
||||
self.num_examples = None
|
||||
if self.num_examples:
|
||||
assert (
|
||||
self.num_examples > 0 or self.num_examples == -1
|
||||
), "`num_examples` must be greater than 0 or equal to -1."
|
||||
if self.num_successful_examples:
|
||||
assert (
|
||||
self.num_successful_examples > 0
|
||||
), "`num_examples` must be greater than 0."
|
||||
|
||||
if self.query_budget:
|
||||
assert self.query_budget > 0, "`query_budget` must be greater than 0"
|
||||
|
||||
if self.checkpoint_interval:
|
||||
assert (
|
||||
self.checkpoint_interval > 0
|
||||
), "`checkpoint_interval` must be greater than 0"
|
||||
|
||||
assert (
|
||||
self.num_workers_per_device > 0
|
||||
), "`num_workers_per_device` must be greater than 0"
|
||||
|
||||
assert self.csv_coloring_style in {
|
||||
"file",
|
||||
"plain",
|
||||
}, '`csv_coloring_style` must either be "file" or "plain".'
|
||||
|
||||
@classmethod
|
||||
def add_parser_args(cls, parser):
|
||||
"""Add listed args to command line parser."""
|
||||
parser.add_argument(
|
||||
default_obj = cls()
|
||||
num_ex_group = parser.add_mutually_exclusive_group(required=False)
|
||||
num_ex_group.add_argument(
|
||||
"--num-examples",
|
||||
"-n",
|
||||
type=int,
|
||||
required=False,
|
||||
default=5,
|
||||
help="The number of examples to process, -1 for entire dataset",
|
||||
default=default_obj.num_examples,
|
||||
help="The number of examples to process, -1 for entire dataset.",
|
||||
)
|
||||
num_ex_group.add_argument(
|
||||
"--num-successful-examples",
|
||||
type=int,
|
||||
default=default_obj.num_successful_examples,
|
||||
help="The number of successful adversarial examples we want.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-examples-offset",
|
||||
"-o",
|
||||
type=int,
|
||||
required=False,
|
||||
default=0,
|
||||
default=default_obj.num_examples_offset,
|
||||
help="The offset to start at in the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-budget",
|
||||
"-q",
|
||||
type=int,
|
||||
default=None,
|
||||
default=default_obj.query_budget,
|
||||
help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.shuffle,
|
||||
help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attack-n",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.attack_n,
|
||||
help="Whether to run attack until `n` examples have been attacked (not skipped).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="checkpoints",
|
||||
default=default_obj.checkpoint_dir,
|
||||
help="The directory to save checkpoint files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-interval",
|
||||
required=False,
|
||||
type=int,
|
||||
default=default_obj.checkpoint_interval,
|
||||
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
|
||||
)
|
||||
|
||||
def str_to_int(s):
|
||||
return sum((ord(c) for c in s))
|
||||
|
||||
parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"), type=int)
|
||||
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
default=default_obj.random_seed,
|
||||
type=int,
|
||||
help="Random seed for reproducibility.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--parallel",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.parallel,
|
||||
help="Run attack using multiple GPUs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers-per-device",
|
||||
default=1,
|
||||
default=default_obj.num_workers_per_device,
|
||||
type=int,
|
||||
help="Number of worker processes to run per device.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-to-txt",
|
||||
nargs="?",
|
||||
default=None,
|
||||
default=default_obj.log_to_txt,
|
||||
const="",
|
||||
type=str,
|
||||
help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
|
||||
"If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-to-csv",
|
||||
nargs="?",
|
||||
default=None,
|
||||
default=default_obj.log_to_csv,
|
||||
const="",
|
||||
type=str,
|
||||
help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
|
||||
"If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--csv-coloring-style",
|
||||
default="file",
|
||||
default=default_obj.csv_coloring_style,
|
||||
type=str,
|
||||
help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
|
||||
'"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-to-visdom",
|
||||
nargs="?",
|
||||
@@ -268,30 +330,25 @@ class AttackArgs:
|
||||
'three keys and their corresponding values: `"env", "port", "hostname"`. '
|
||||
'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-to-wandb",
|
||||
nargs="?",
|
||||
default=None,
|
||||
default=default_obj.log_to_wandb,
|
||||
const="textattack",
|
||||
type=str,
|
||||
help="Name of the wandb project. Set this argument if you want to log attacks to Wandb.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-stdout",
|
||||
action="store_true",
|
||||
default=default_obj.disable_stdout,
|
||||
help="Disable logging attack results to stdout",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--silent", action="store_true", default=False, help="Disable all logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-exceptions",
|
||||
"--silent",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Skip examples that raise an error instead of exiting.",
|
||||
default=default_obj.silent,
|
||||
help="Disable all logging",
|
||||
)
|
||||
|
||||
return parser
|
||||
@@ -356,23 +413,36 @@ class AttackArgs:
|
||||
|
||||
@dataclass
|
||||
class _CommandLineAttackArgs:
|
||||
"""Command line interface attack args. This requires more arguments to
|
||||
"""Attack args for command line execution. This requires more arguments to
|
||||
create ``Attack`` object as specified.
|
||||
|
||||
Args:
|
||||
transformation (str): Name of transformation to use.
|
||||
constraints (list[str]): List of names of constraints to use.
|
||||
goal_function (str): Name of goal function to use.
|
||||
search_method (str): Name of search method to use.
|
||||
attack_recipe (str): Name of attack recipe to use.
|
||||
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
|
||||
attack_from_file (str): Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.
|
||||
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
|
||||
interactive (bool): If `True`, carry attack in interactive mode. Default is `False`.
|
||||
parallel (bool): If `True`, attack in parallel. Default is `False`.
|
||||
model_batch_size (int): The batch size for making calls to the model.
|
||||
model_cache_size (int): The maximum number of items to keep in the model results cache at once.
|
||||
constraint-cache-size (int): The maximum number of items to keep in the constraints cache at once.
|
||||
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
|
||||
Name of transformation to use.
|
||||
constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
|
||||
List of names of constraints to use.
|
||||
goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
|
||||
Name of goal function to use.
|
||||
search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
|
||||
Name of search method to use.
|
||||
attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Name of attack recipe to use.
|
||||
.. note::
|
||||
Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
|
||||
attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Path of `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.
|
||||
.. note::
|
||||
If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
|
||||
interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, carry attack in interactive mode.
|
||||
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, attack in parallel.
|
||||
model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
|
||||
The batch size for making queries to the victim model.
|
||||
model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
|
||||
The maximum number of items to keep in the model results cache at once.
|
||||
constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
|
||||
The maximum number of items to keep in the constraints cache at once.
|
||||
"""
|
||||
|
||||
transformation: str = "word-swap-embedding"
|
||||
@@ -390,6 +460,7 @@ class _CommandLineAttackArgs:
|
||||
@classmethod
|
||||
def add_parser_args(cls, parser):
|
||||
"""Add listed args to command line parser."""
|
||||
default_obj = cls()
|
||||
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
|
||||
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
|
||||
)
|
||||
@@ -397,7 +468,7 @@ class _CommandLineAttackArgs:
|
||||
"--transformation",
|
||||
type=str,
|
||||
required=False,
|
||||
default="word-swap-embedding",
|
||||
default=default_obj.transformation,
|
||||
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
|
||||
+ str(transformation_names),
|
||||
)
|
||||
@@ -406,7 +477,7 @@ class _CommandLineAttackArgs:
|
||||
type=str,
|
||||
required=False,
|
||||
nargs="*",
|
||||
default=["repeat", "stopword"],
|
||||
default=default_obj.constraints,
|
||||
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
|
||||
+ str(CONSTRAINT_CLASS_NAMES.keys()),
|
||||
)
|
||||
@@ -414,7 +485,7 @@ class _CommandLineAttackArgs:
|
||||
parser.add_argument(
|
||||
"--goal-function",
|
||||
"-g",
|
||||
default="untargeted-classification",
|
||||
default=default_obj.goal_function,
|
||||
help=f"The goal function to use. choices: {goal_function_choices}",
|
||||
)
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
@@ -425,7 +496,7 @@ class _CommandLineAttackArgs:
|
||||
"-s",
|
||||
type=str,
|
||||
required=False,
|
||||
default="greedy-word-wir",
|
||||
default=default_obj.search_method,
|
||||
help=f"The search method to use. choices: {search_choices}",
|
||||
)
|
||||
attack_group.add_argument(
|
||||
@@ -434,7 +505,7 @@ class _CommandLineAttackArgs:
|
||||
"-r",
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
default=default_obj.attack_recipe,
|
||||
help="full attack recipe (overrides provided goal function, transformation & constraints)",
|
||||
choices=ATTACK_RECIPE_NAMES.keys(),
|
||||
)
|
||||
@@ -442,31 +513,31 @@ class _CommandLineAttackArgs:
|
||||
"--attack-from-file",
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
default=default_obj.attack_from_file,
|
||||
help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.interactive,
|
||||
help="Whether to run attacks interactively.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
default=default_obj.model_batch_size,
|
||||
help="The batch size for making calls to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
default=default_obj.model_cache_size,
|
||||
help="The maximum number of items to keep in the model results cache at once.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--constraint-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
default=default_obj.constraint_cache_size,
|
||||
help="The maximum number of items to keep in the constraints cache at once.",
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
@@ -126,13 +127,13 @@ class Attacker:
|
||||
else:
|
||||
self._attack()
|
||||
|
||||
# Turn logger back on since other modules can be using Attacker (e.g. Trainer)
|
||||
if self.attack_args.silent:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
return self.attack_log_manager.results
|
||||
|
||||
def _get_worklist(self, start, end, num_examples, shuffle):
|
||||
if end - start > num_examples:
|
||||
if end - start < num_examples:
|
||||
logger.warn(
|
||||
f"Attempting to attack {num_examples} samples when only {end-start} are available."
|
||||
)
|
||||
@@ -141,6 +142,7 @@ class Attacker:
|
||||
random.shuffle(candidates)
|
||||
worklist = collections.deque(candidates[:num_examples])
|
||||
candidates = collections.deque(candidates[num_examples:])
|
||||
assert (len(worklist) + len(candidates)) == (end - start)
|
||||
return worklist, candidates
|
||||
|
||||
def _attack(self):
|
||||
@@ -157,15 +159,26 @@ class Attacker:
|
||||
f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
|
||||
)
|
||||
else:
|
||||
num_remaining_attacks = self.attack_args.num_examples
|
||||
# We make `worklist` deque (linked-list) for easy pop and append.
|
||||
# Candidates are other samples we can attack if we need more samples.
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
if self.attack_args.num_successful_examples:
|
||||
num_remaining_attacks = self.attack_args.num_successful_examples
|
||||
# We make `worklist` deque (linked-list) for easy pop and append.
|
||||
# Candidates are other samples we can attack if we need more samples.
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_successful_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
else:
|
||||
num_remaining_attacks = self.attack_args.num_examples
|
||||
# We make `worklist` deque (linked-list) for easy pop and append.
|
||||
# Candidates are other samples we can attack if we need more samples.
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
|
||||
if not self.attack_args.silent:
|
||||
print(self.attack, "\n")
|
||||
@@ -174,19 +187,20 @@ class Attacker:
|
||||
if self._checkpoint:
|
||||
num_results = self._checkpoint.results_count
|
||||
num_failures = self._checkpoint.num_failed_attacks
|
||||
num_skipped = self._checkpoint.num_skipped_attacks
|
||||
num_successes = self._checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_skipped = 0
|
||||
num_successes = 0
|
||||
|
||||
if hasattr(self.attack.goal_function.model, "to"):
|
||||
self.attack.goal_function.model.to(textattack.shared.utils.device)
|
||||
|
||||
i = 0
|
||||
sample_exhaustion_warned = False
|
||||
while worklist:
|
||||
idx = worklist.popleft()
|
||||
i += 1
|
||||
try:
|
||||
example, ground_truth_output = self.dataset[idx]
|
||||
except IndexError:
|
||||
@@ -197,32 +211,36 @@ class Attacker:
|
||||
try:
|
||||
result = self.attack.attack(example, ground_truth_output)
|
||||
except Exception as e:
|
||||
if self.attack_args.ignore_exceptions:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
self.attack_log_manager.log_result(result)
|
||||
|
||||
if not self.attack_args.disable_stdout:
|
||||
print("\n")
|
||||
|
||||
if isinstance(result, SkippedAttackResult) and self.attack_args.attack_n:
|
||||
raise e
|
||||
if (
|
||||
isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
|
||||
) or (
|
||||
not isinstance(result, SuccessfulAttackResult)
|
||||
and self.attack_args.num_successful_examples
|
||||
):
|
||||
if worklist_candidates:
|
||||
next_sample = worklist_candidates.popleft()
|
||||
worklist.append(next_sample)
|
||||
else:
|
||||
logger.warn("`attack_n=True` but no more samples to attack.")
|
||||
if not sample_exhaustion_warned:
|
||||
logger.warn("Ran out of samples to attack!")
|
||||
sample_exhaustion_warned = True
|
||||
else:
|
||||
pbar.update(1)
|
||||
|
||||
self.attack_log_manager.log_result(result)
|
||||
if not self.attack_args.disable_stdout and not self.attack_args.silent:
|
||||
print("\n")
|
||||
num_results += 1
|
||||
|
||||
if isinstance(result, SkippedAttackResult):
|
||||
num_skipped += 1
|
||||
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
|
||||
num_successes += 1
|
||||
if isinstance(result, FailedAttackResult):
|
||||
num_failures += 1
|
||||
pbar.set_description(
|
||||
f"[Succeeded / Failed / Total] {num_successes} / {num_failures} / {num_results}"
|
||||
f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -260,13 +278,26 @@ class Attacker:
|
||||
f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
|
||||
)
|
||||
else:
|
||||
num_remaining_attacks = self.attack_args.num_examples
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
if self.attack_args.num_successful_examples:
|
||||
num_remaining_attacks = self.attack_args.num_successful_examples
|
||||
# We make `worklist` deque (linked-list) for easy pop and append.
|
||||
# Candidates are other samples we can attack if we need more samples.
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_successful_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
else:
|
||||
num_remaining_attacks = self.attack_args.num_examples
|
||||
# We make `worklist` deque (linked-list) for easy pop and append.
|
||||
# Candidates are other samples we can attack if we need more samples.
|
||||
worklist, worklist_candidates = self._get_worklist(
|
||||
self.attack_args.num_examples_offset,
|
||||
len(self.dataset),
|
||||
self.attack_args.num_examples,
|
||||
self.attack_args.shuffle,
|
||||
)
|
||||
|
||||
in_queue = torch.multiprocessing.Queue()
|
||||
out_queue = torch.multiprocessing.Queue()
|
||||
@@ -313,25 +344,38 @@ class Attacker:
|
||||
if self._checkpoint:
|
||||
num_results = self._checkpoint.results_count
|
||||
num_failures = self._checkpoint.num_failed_attacks
|
||||
num_skipped = self._checkpoint.num_skipped_attacks
|
||||
num_successes = self._checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_skipped = 0
|
||||
num_successes = 0
|
||||
|
||||
logger.info(f"Worklist size: {len(worklist)}")
|
||||
logger.info(f"Worklist candidate size: {len(worklist_candidates)}")
|
||||
|
||||
sample_exhaustion_warned = False
|
||||
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
|
||||
while worklist:
|
||||
idx, result = out_queue.get(block=True)
|
||||
worklist.remove(idx)
|
||||
if isinstance(result, Exception):
|
||||
if self.attack_args.ignore_exceptions:
|
||||
continue
|
||||
else:
|
||||
worker_pool.terminate()
|
||||
worker_pool.join()
|
||||
raise result
|
||||
|
||||
self.attack_log_manager.log_result(result)
|
||||
if self.attack_args.attack_n and isinstance(result, SkippedAttackResult):
|
||||
if isinstance(result, tuple) and isinstance(result[0], Exception):
|
||||
logger.error(
|
||||
f'Exception encountered for input "{self.dataset[idx][0]}".'
|
||||
)
|
||||
error_trace = result[1]
|
||||
logger.error(error_trace)
|
||||
worker_pool.terminate()
|
||||
worker_pool.join()
|
||||
exit(1)
|
||||
elif (
|
||||
isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
|
||||
) or (
|
||||
not isinstance(result, SuccessfulAttackResult)
|
||||
and self.attack_args.num_successful_examples
|
||||
):
|
||||
if worklist_candidates:
|
||||
next_sample = worklist_candidates.popleft()
|
||||
example, ground_truth_output = self.dataset[next_sample]
|
||||
@@ -341,21 +385,24 @@ class Attacker:
|
||||
worklist.append(next_sample)
|
||||
in_queue.put((next_sample, example, ground_truth_output))
|
||||
else:
|
||||
logger.warn(
|
||||
f"Attempted to attack {self.attack_args.num_examples} examples with but ran out of examples. "
|
||||
f"You might see fewer number of results than {self.attack_args.num_examples}."
|
||||
)
|
||||
if not sample_exhaustion_warned:
|
||||
logger.warn("Ran out of samples to attack!")
|
||||
sample_exhaustion_warned = True
|
||||
else:
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
|
||||
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
|
||||
num_successes += 1
|
||||
if isinstance(result, FailedAttackResult):
|
||||
num_failures += 1
|
||||
pbar.set_description(
|
||||
f"[Succeeded / Failed / Total] {num_successes} / {num_failures} / {num_results}"
|
||||
)
|
||||
self.attack_log_manager.log_result(result)
|
||||
num_results += 1
|
||||
|
||||
if isinstance(result, SkippedAttackResult):
|
||||
num_skipped += 1
|
||||
if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
|
||||
num_successes += 1
|
||||
if isinstance(result, FailedAttackResult):
|
||||
num_failures += 1
|
||||
pbar.set_description(
|
||||
f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.attack_args.checkpoint_interval
|
||||
@@ -372,7 +419,10 @@ class Attacker:
|
||||
new_checkpoint.save()
|
||||
self.attack_log_manager.flush()
|
||||
|
||||
worker_pool.terminate()
|
||||
# Send sentinel values to worker processes
|
||||
for _ in range(num_workers):
|
||||
in_queue.put(("END", "END", "END"))
|
||||
worker_pool.close()
|
||||
worker_pool.join()
|
||||
|
||||
pbar.close()
|
||||
@@ -504,11 +554,14 @@ def attack_from_queue(
|
||||
while True:
|
||||
try:
|
||||
i, example, ground_truth_output = in_queue.get(timeout=5)
|
||||
result = attack.attack(example, ground_truth_output)
|
||||
out_queue.put((i, result))
|
||||
if i == "END" and example == "END" and ground_truth_output == "END":
|
||||
# End process when sentinel value is received
|
||||
break
|
||||
else:
|
||||
result = attack.attack(example, ground_truth_output)
|
||||
out_queue.put((i, result))
|
||||
except Exception as e:
|
||||
if isinstance(e, queue.Empty):
|
||||
continue
|
||||
out_queue.put((i, e))
|
||||
if not attack_args.ignore_exceptions:
|
||||
exit(1)
|
||||
else:
|
||||
out_queue.put((i, (e, traceback.format_exc(e))))
|
||||
|
||||
@@ -10,3 +10,4 @@ from .repeat_modification import RepeatModification
|
||||
from .input_column_modification import InputColumnModification
|
||||
from .max_word_index_modification import MaxWordIndexModification
|
||||
from .min_word_length import MinWordLength
|
||||
from .max_modification_rate import MaxModificationRate
|
||||
|
||||
@@ -92,7 +92,7 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
|
||||
self.model.zero_grad()
|
||||
model_device = next(self.model.parameters()).device
|
||||
input_dict = self.tokenizer(
|
||||
text_input,
|
||||
[text_input],
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
@@ -135,7 +135,7 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
|
||||
"""
|
||||
return [
|
||||
self.tokenizer.convert_ids_to_tokens(
|
||||
self.tokenizer(x, truncation=True)["input_ids"]
|
||||
self.tokenizer([x], truncation=True)["input_ids"][0]
|
||||
)
|
||||
for x in inputs
|
||||
]
|
||||
|
||||
@@ -72,6 +72,9 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
continue
|
||||
swap_results, _ = self.get_goal_results(transformed_text_candidates)
|
||||
score_change = [result.score for result in swap_results]
|
||||
if not score_change:
|
||||
delta_ps.append(0.0)
|
||||
continue
|
||||
max_score_change = np.max(score_change)
|
||||
delta_ps.append(max_score_change)
|
||||
|
||||
|
||||
@@ -103,11 +103,13 @@ class AttackedText:
|
||||
"""
|
||||
if "previous_attacked_text" in self.attack_attrs:
|
||||
self.attack_attrs["previous_attacked_text"].free_memory()
|
||||
if "last_transformation" in self.attack_attrs:
|
||||
del self.attack_attrs["last_transformation"]
|
||||
self.attack_attrs.pop("previous_attacked_text", None)
|
||||
|
||||
self.attack_attrs.pop("last_transformation", None)
|
||||
|
||||
for key in self.attack_attrs:
|
||||
if isinstance(self.attack_attrs[key], torch.Tensor):
|
||||
del self.attack_attrs[key]
|
||||
self.attack_attrs.pop(key, None)
|
||||
|
||||
def text_window_around_index(self, index, window_size):
|
||||
"""The text window of ``window_size`` words centered around
|
||||
|
||||
@@ -1,42 +1,28 @@
|
||||
import collections
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
import tqdm
|
||||
import transformers
|
||||
|
||||
import textattack
|
||||
from textattack.models.helpers import LSTMForClassification, WordCNNForClassification
|
||||
from textattack.models.wrappers import ModelWrapper
|
||||
|
||||
from .attack import Attack
|
||||
from .attack_args import AttackArgs
|
||||
from .attack_results import MaximizedAttackResult, SuccessfulAttackResult
|
||||
from .attacker import Attacker
|
||||
from .model_args import HUGGINGFACE_MODELS
|
||||
from .models.helpers import LSTMForClassification, WordCNNForClassification
|
||||
from .models.wrappers import ModelWrapper
|
||||
from .training_args import CommandLineTrainingArgs, TrainingArgs
|
||||
|
||||
logger = textattack.shared.logger
|
||||
|
||||
|
||||
# Helper functions for collating data
|
||||
def collate_fn(input_columns, data):
    """Collate ``(input, label)`` pairs into a batch of texts and a label tensor.

    Args:
        input_columns:
            Iterable of keys used to pull the text fields out of each input.
        data:
            Iterable of ``(input, label)`` pairs, where ``input`` supports
            lookup by each key in ``input_columns``.

    Returns:
        Tuple ``(input_texts, labels)``. ``input_texts`` is a list holding, per
        example, either the single selected field or a tuple of the selected
        fields. ``labels`` is ``torch.tensor`` built from the labels.
    """

    def _select_fields(example):
        # A single-column input is unwrapped so the batch holds plain values
        # rather than 1-tuples.
        fields = tuple(example[column] for column in input_columns)
        return fields[0] if len(fields) == 1 else fields

    # Single pass over `data`, keeping texts and labels aligned.
    pairs = [(_select_fields(example), label) for example, label in data]
    input_texts = [text for text, _ in pairs]
    labels = [label for _, label in pairs]
    return input_texts, torch.tensor(labels)
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""Trainer is training and eval loop for adversarial training.
|
||||
|
||||
@@ -47,10 +33,10 @@ class Trainer:
|
||||
self,
|
||||
model_wrapper,
|
||||
task_type,
|
||||
attack,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
training_args,
|
||||
attack=None,
|
||||
train_dataset=None,
|
||||
eval_dataset=None,
|
||||
training_args=TrainingArgs(),
|
||||
):
|
||||
assert isinstance(
|
||||
model_wrapper, ModelWrapper
|
||||
@@ -60,15 +46,18 @@ class Trainer:
|
||||
"classification",
|
||||
"regression",
|
||||
}, '`task_type` must either be "classification" or "regression"'
|
||||
assert isinstance(
|
||||
attack, Attack
|
||||
), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
|
||||
assert isinstance(
|
||||
train_dataset, textattack.datasets.Dataset
|
||||
), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."
|
||||
assert isinstance(
|
||||
eval_dataset, textattack.datasets.Dataset
|
||||
), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."
|
||||
if attack:
|
||||
assert isinstance(
|
||||
attack, Attack
|
||||
), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
|
||||
if train_dataset:
|
||||
assert isinstance(
|
||||
train_dataset, textattack.datasets.Dataset
|
||||
), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."
|
||||
if eval_dataset:
|
||||
assert isinstance(
|
||||
eval_dataset, textattack.datasets.Dataset
|
||||
), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."
|
||||
assert isinstance(
|
||||
training_args, TrainingArgs
|
||||
), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
|
||||
@@ -94,121 +83,88 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.training_args = training_args
|
||||
|
||||
def _generate_adversarial_examples(self, dataset, epoch, eval_mode=False):
|
||||
"""Generate adversarial examples using attacker."""
|
||||
if eval_mode:
|
||||
logger.info("Attacking model to evaluate adversarial robustness...")
|
||||
self._metric_name = (
|
||||
"pearson_correlation" if self.task_type == "regression" else "accuracy"
|
||||
)
|
||||
if self.task_type == "regression":
|
||||
self.loss_fct = torch.nn.MSELoss(reduction="none")
|
||||
else:
|
||||
logger.info("Attacking model to generate new adversarial training set...")
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
num_examples = (
|
||||
self.training_args.num_eval_adv_examples
|
||||
if eval_mode
|
||||
else self.training_args.num_train_adv_examples
|
||||
)
|
||||
query_budget = (
|
||||
self.training_args.query_budget_eval
|
||||
if eval_mode
|
||||
else self.training_args.query_budget_train
|
||||
)
|
||||
shuffle = False if eval_mode else True
|
||||
base_file_name = (
|
||||
f"attack-eval-{epoch}" if eval_mode else f"attack-train-{epoch}"
|
||||
)
|
||||
self._global_step = 0
|
||||
|
||||
def _generate_adversarial_examples(self, epoch):
|
||||
"""Generate adversarial examples using attacker."""
|
||||
base_file_name = f"attack-train-{epoch}"
|
||||
log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
|
||||
logger.info("Attacking model to generate new adversarial training set...")
|
||||
|
||||
if isinstance(self.training_args.num_train_adv_examples, float):
|
||||
num_train_adv_examples = math.ceil(
|
||||
len(self.train_dataset) * self.training_args.num_train_adv_examples
|
||||
)
|
||||
else:
|
||||
num_train_adv_examples = self.training_args.num_train_adv_examples
|
||||
|
||||
attack_args = AttackArgs(
|
||||
num_examples=num_examples,
|
||||
num_successful_examples=num_train_adv_examples,
|
||||
num_examples_offset=0,
|
||||
query_budget=query_budget,
|
||||
shuffle=shuffle,
|
||||
attack_n=True,
|
||||
query_budget=self.training_args.query_budget_train,
|
||||
shuffle=True,
|
||||
parallel=self.training_args.parallel,
|
||||
num_workers_per_device=self.training_args.attack_num_workers_per_device,
|
||||
disable_stdout=True,
|
||||
silent=True,
|
||||
ignore_exceptions=True,
|
||||
log_to_txt=log_file_name + ".txt",
|
||||
log_to_csv=log_file_name + ".csv",
|
||||
)
|
||||
attacker = Attacker(self.attack, dataset, attack_args)
|
||||
|
||||
attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
|
||||
results = attacker.attack_dataset()
|
||||
|
||||
if eval_mode:
|
||||
return results
|
||||
else:
|
||||
attack_types = collections.Counter(r.__class__.__name__ for r in results)
|
||||
total_attacks = (
|
||||
attack_types["SuccessfulAttackResult"]
|
||||
+ attack_types["FailedAttackResult"]
|
||||
)
|
||||
success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
|
||||
logger.info(
|
||||
f"Attack Success Rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
|
||||
)
|
||||
adversarial_examples = [
|
||||
(
|
||||
tuple(r.perturbed_result.attacked_text._text_input.values()),
|
||||
r.perturbed_result.ground_truth_output,
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
adversarial_dataset = textattack.datasets.Dataset(
|
||||
adversarial_examples,
|
||||
input_columns=dataset.input_columns,
|
||||
label_map=dataset.label_map,
|
||||
label_names=dataset.label_names,
|
||||
output_scale_factor=dataset.output_scale_factor,
|
||||
shuffle=False,
|
||||
)
|
||||
return adversarial_dataset
|
||||
|
||||
def _training_setup(self):
|
||||
"""Handle all the training set ups including logging."""
|
||||
textattack.shared.utils.set_seed(self.training_args.random_seed)
|
||||
if not os.path.exists(self.training_args.output_dir):
|
||||
os.makedirs(self.training_args.output_dir)
|
||||
|
||||
# Save logger writes to file
|
||||
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
|
||||
fh = logging.FileHandler(log_txt_path)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
logger.addHandler(fh)
|
||||
logger.info(f"Writing logs to {log_txt_path}.")
|
||||
|
||||
# Save original self.training_args to file
|
||||
args_save_path = os.path.join(
|
||||
self.training_args.output_dir, "training_args.json"
|
||||
attack_types = collections.Counter(r.__class__.__name__ for r in results)
|
||||
total_attacks = (
|
||||
attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
|
||||
)
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.training_args.__dict__, f)
|
||||
logger.info(f"Wrote original training args to {args_save_path}.")
|
||||
|
||||
def _print_training_args(self, total_training_steps):
|
||||
logger.info("==================== Running Training ====================")
|
||||
logger.info(f"Num epochs = {self.training_args.num_epochs}")
|
||||
logger.info(f"Num clean epochs = {self.training_args.num_clean_epochs}")
|
||||
logger.info(f"Num total steps = {total_training_steps}")
|
||||
logger.info(f"Num training examples = {len(self.train_dataset)}")
|
||||
logger.info(f"Num evaluation examples = {len(self.eval_dataset)}")
|
||||
logger.info(f"Starting learning rate = {self.training_args.learning_rate}")
|
||||
logger.info(f"Num warmup steps = {self.training_args.num_warmup_steps}")
|
||||
logger.info(f"Weight decay = {self.training_args.weight_decay}")
|
||||
|
||||
def _get_tensorboard_writer(self):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
tb_writer = SummaryWriter(self.training_args.tb_log_dir)
|
||||
tb_writer.add_hparams(self.training_args.__dict__, {})
|
||||
tb_writer.flush()
|
||||
return tb_writer
|
||||
|
||||
def _init_wandb(self):
|
||||
global wandb
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=self.training_args.wand_project, config=self.training_args.__dict__
|
||||
success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
|
||||
logger.info(f"Total number of attack results: {len(results)}")
|
||||
logger.info(
|
||||
f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
|
||||
)
|
||||
adversarial_examples = [
|
||||
(
|
||||
tuple(r.perturbed_result.attacked_text._text_input.values()),
|
||||
r.perturbed_result.ground_truth_output,
|
||||
"adversarial_example",
|
||||
)
|
||||
for r in results
|
||||
if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
|
||||
]
|
||||
adversarial_dataset = textattack.datasets.Dataset(
|
||||
adversarial_examples,
|
||||
input_columns=self.train_dataset.input_columns,
|
||||
label_map=self.train_dataset.label_map,
|
||||
label_names=self.train_dataset.label_names,
|
||||
output_scale_factor=self.train_dataset.output_scale_factor,
|
||||
shuffle=False,
|
||||
)
|
||||
return adversarial_dataset
|
||||
|
||||
def _print_training_args(self, total_training_steps, train_batch_size):
    """Log a summary of the training configuration.

    Args:
        total_training_steps:
            Value reported as the total number of optimization steps.
        train_batch_size:
            Batch size that is multiplied by the gradient accumulation steps
            to report the total train batch size.
    """
    log = logger.info

    log("***** Running training *****")
    log(f" Num examples = {len(self.train_dataset)}")

    args = self.training_args
    log(f" Num epochs = {args.num_epochs}")
    log(f" Num clean epochs = {args.num_clean_epochs}")
    log(
        f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )

    # Effective batch size once gradient accumulation is taken into account.
    total_batch_size = train_batch_size * args.gradient_accumulation_steps
    log(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    log(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
    log(f" Total optimization steps = {total_training_steps}")
|
||||
|
||||
def _save_model_checkpoint(
|
||||
self, model, tokenizer, step=None, epoch=None, best=False, last=False
|
||||
@@ -242,7 +198,36 @@ class Trainer:
|
||||
os.path.join(output_dir, "pytorch_model.bin"),
|
||||
)
|
||||
|
||||
def _get_optimizer_and_scheduler(self, model, total_training_steps):
|
||||
def _tb_log(self, log, step):
    """Write the scalars in ``log`` to TensorBoard at ``step``.

    The ``SummaryWriter`` is created on the first call, stored on the
    instance as ``_tb_writer``, and reused on later calls.

    Args:
        log (:obj:`dict`):
            Mapping from scalar name to the value to record.
        step (:obj:`int`):
            Step index passed to ``add_scalar`` for every entry.
    """
    if not hasattr(self, "_tb_writer"):
        # Imported lazily so TensorBoard is only needed when it is used.
        from torch.utils.tensorboard import SummaryWriter

        self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
        self._tb_writer.add_hparams(self.training_args.__dict__, {})
        self._tb_writer.flush()

    for key, value in log.items():
        self._tb_writer.add_scalar(key, value, step)

    # The writer is stored as `_tb_writer`; flushing `self.tb_writer` raised
    # AttributeError because that attribute is never set.
    self._tb_writer.flush()
|
||||
|
||||
def _wandb_log(self, log, step):
    """Send ``log`` to Weights & Biases at ``step``.

    On the first call the ``wandb`` module is imported into the module
    globals, the instance is marked as initialized, and ``wandb.init`` is
    called with the project name and training arguments.

    Args:
        log (:obj:`dict`):
            Payload forwarded to ``wandb.log``.
        step (:obj:`int`):
            Step passed to ``wandb.log`` as the ``step`` keyword.
    """
    global wandb

    first_call = not hasattr(self, "_wandb_init")
    if first_call:
        import wandb

        self._wandb_init = True
        project_name = self.training_args.wandb_project
        run_config = self.training_args.__dict__
        wandb.init(project=project_name, config=run_config)

    wandb.log(log, step=step)
|
||||
|
||||
def get_optimizer_and_scheduler(self, model, total_training_steps):
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
model = model.module
|
||||
|
||||
if isinstance(model, transformers.PreTrainedModel):
|
||||
# Reference https://huggingface.co/transformers/training.html
|
||||
param_optimizer = list(model.named_parameters())
|
||||
@@ -267,10 +252,16 @@ class Trainer:
|
||||
optimizer = transformers.optimization.AdamW(
|
||||
optimizer_grouped_parameters, lr=self.training_args.learning_rate
|
||||
)
|
||||
if isinstance(self.training_args.num_warmup_steps, float):
|
||||
num_warmup_steps = math.ceil(
|
||||
self.training_args.num_warmup_steps * total_training_steps
|
||||
)
|
||||
else:
|
||||
num_warmup_steps = self.training_args.num_warmup_steps
|
||||
|
||||
scheduler = transformers.optimization.get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=self.training_args.num_warmup_steps,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=total_training_steps,
|
||||
)
|
||||
else:
|
||||
@@ -282,8 +273,162 @@ class Trainer:
|
||||
|
||||
return optimizer, scheduler
|
||||
|
||||
def get_train_dataloader(self, dataset, train_batch_size):
    """Build the shuffled training ``DataLoader`` for ``dataset``.

    Args:
        dataset:
            Map-style dataset whose items are ``(input, label)`` pairs or, for
            adversarial examples, ``(input, label, "adversarial_example")``
            triples. It must expose ``input_columns``, or be a
            ``torch.utils.data.ConcatDataset`` whose first constituent does.
        train_batch_size (:obj:`int`):
            Number of examples per batch.

    Returns:
        :obj:`torch.utils.data.DataLoader` yielding
        ``(input_texts, targets, is_adv_sample)`` where ``targets`` and
        ``is_adv_sample`` are tensors.

    Raises:
        ValueError: Raised during collation if an item has three elements but
            the last one is not the ``"adversarial_example"`` marker.
    """
    # `ConcatDataset` does not expose `input_columns`, yet that is what the
    # adversarial epochs pass in, so fall back to the first constituent.
    if isinstance(dataset, torch.utils.data.ConcatDataset) and not hasattr(
        dataset, "input_columns"
    ):
        input_columns = dataset.datasets[0].input_columns
    else:
        input_columns = dataset.input_columns

    def collate_fn(data):
        input_texts = []
        targets = []
        is_adv_sample = []
        for item in data:
            if len(item) == 3:
                # Three elements mark an adversarial training example.
                _input, label, adv = item
                if adv != "adversarial_example":
                    raise ValueError(
                        "`item` has length of 3 but last element is not for marking if the item is an `adversarial example`."
                    )
                is_adv_sample.append(True)
            else:
                # Otherwise the item is a plain `(input, label)` pair.
                _input, label = item
                is_adv_sample.append(False)

            # NOTE(review): assumes every input supports lookup by column
            # name, including adversarial ones; confirm against the dataset.
            _input = tuple(_input[c] for c in input_columns)
            if len(_input) == 1:
                _input = _input[0]
            input_texts.append(_input)
            targets.append(label)

        return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return train_dataloader
|
||||
|
||||
def get_eval_dataloader(self, dataset, eval_batch_size):
    """Build the ``DataLoader`` used for evaluation.

    Args:
        dataset:
            Map-style dataset of ``(input, label)`` pairs exposing
            ``input_columns``.
        eval_batch_size (:obj:`int`):
            Number of examples per batch.

    Returns:
        :obj:`torch.utils.data.DataLoader` yielding ``(input_texts, targets)``
        with ``targets`` as a tensor. The loader is created with
        ``shuffle=True``.
    """

    def collate_fn(data):
        # Gather the selected text fields and the labels of one batch.
        batch_texts, batch_targets = [], []
        for example, target in data:
            fields = tuple(example[column] for column in dataset.input_columns)
            # A single-column input is stored as the bare value.
            batch_texts.append(fields[0] if len(fields) == 1 else fields)
            batch_targets.append(target)
        return batch_texts, torch.tensor(batch_targets)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=eval_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=True,
    )
|
||||
|
||||
def training_step(self, model, tokenizer, batch):
    """Run the forward pass and compute the weighted loss for one batch.

    Args:
        model (:obj:`torch.nn.Module`):
            Model to train.
        tokenizer:
            Tokenizer used to tokenize input text.
        batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
            Tuple of input texts, targets, and boolean tensor indicating if the sample is an adversarial example.

    Returns:
        Tuple ``(loss, preds, targets)``: the mean of the per-sample losses
        after weighting, the predictions moved to CPU, and the targets
        exactly as they were passed in ``batch``.
    """
    input_texts, original_targets, is_adv_sample = batch
    device = textattack.shared.utils.device
    targets = original_targets.to(device)

    # A transformers model is recognized either directly or when wrapped in
    # `DataParallel`.
    is_hf_model = isinstance(model, transformers.PreTrainedModel)
    if not is_hf_model and isinstance(model, torch.nn.DataParallel):
        is_hf_model = isinstance(model.module, transformers.PreTrainedModel)

    if is_hf_model:
        encoded = tokenizer(
            input_texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        # Return value deliberately unused, as in the original.
        # NOTE(review): presumably `.to` moves the encoding in place; confirm.
        encoded.to(device)
        logits = model(**encoded)[0]
    else:
        token_ids = tokenizer(input_texts)
        if not isinstance(token_ids, torch.Tensor):
            token_ids = torch.tensor(token_ids)
        logits = model(token_ids.to(device))

    if self.task_type == "regression":
        per_sample_loss = self.loss_fct(logits.squeeze(), targets.squeeze())
        preds = logits
    else:
        per_sample_loss = self.loss_fct(logits, targets)
        preds = logits.argmax(dim=-1)

    # Adversarial samples have their loss scaled by `alpha`; clean samples
    # keep a weight of one.
    sample_weights = torch.ones(is_adv_sample.size(), device=device)
    sample_weights[is_adv_sample] *= self.training_args.alpha
    loss = torch.mean(per_sample_loss * sample_weights)

    return loss, preds.cpu(), original_targets
|
||||
|
||||
def evaluate_step(self, model, tokenizer, batch):
    """Run the forward pass for one evaluation batch.

    Args:
        model (:obj:`torch.nn.Module`):
            Model to evaluate.
        tokenizer:
            Tokenizer used to tokenize input text.
        batch (:obj:`tuple[list[str], torch.Tensor]`):
            Tuple of input texts and targets.

    Returns:
        Tuple ``(preds, targets)``: the predictions moved to CPU (raw logits
        for regression, argmax over the last dimension otherwise) and the
        targets exactly as they were passed in ``batch``.
    """
    # Targets are returned untouched; the previous device copy was never used.
    input_texts, targets = batch
    device = textattack.shared.utils.device

    # Mirror `training_step`: a transformers model wrapped in `DataParallel`
    # must take the Hugging Face tokenization path as well.
    is_hf_model = isinstance(model, transformers.PreTrainedModel) or (
        isinstance(model, torch.nn.DataParallel)
        and isinstance(model.module, transformers.PreTrainedModel)
    )

    if is_hf_model:
        input_ids = tokenizer(
            input_texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        # NOTE(review): presumably `.to` moves the encoding in place; confirm.
        input_ids.to(device)
        logits = model(**input_ids)[0]
    else:
        input_ids = tokenizer(input_texts)
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids)
        input_ids = input_ids.to(device)
        logits = model(input_ids)

    if self.task_type == "regression":
        preds = logits
    else:
        preds = logits.argmax(dim=-1)

    return preds.cpu(), targets
|
||||
|
||||
def train(self):
|
||||
self._training_setup()
|
||||
textattack.shared.utils.set_seed(self.training_args.random_seed)
|
||||
if not os.path.exists(self.training_args.output_dir):
|
||||
os.makedirs(self.training_args.output_dir)
|
||||
|
||||
# Save logger writes to file
|
||||
log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
|
||||
fh = logging.FileHandler(log_txt_path)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
logger.addHandler(fh)
|
||||
logger.info(f"Writing logs to {log_txt_path}.")
|
||||
|
||||
# Save original self.training_args to file
|
||||
args_save_path = os.path.join(
|
||||
self.training_args.output_dir, "training_args.json"
|
||||
)
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.training_args.__dict__, f)
|
||||
logger.info(f"Wrote original training args to {args_save_path}.")
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
tokenizer = self.model_wrapper.tokenizer
|
||||
model = self.model_wrapper.model
|
||||
@@ -297,41 +442,39 @@ class Trainer:
|
||||
else:
|
||||
train_batch_size = self.training_args.per_device_train_batch_size
|
||||
|
||||
model.to(textattack.shared.utils.device)
|
||||
total_training_steps = (
|
||||
total_clean_training_steps = (
|
||||
math.ceil(
|
||||
len(self.train_dataset)
|
||||
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
|
||||
)
|
||||
* self.training_args.num_epochs
|
||||
* self.training_args.num_clean_epochs
|
||||
)
|
||||
total_adv_training_steps = math.ceil(
|
||||
(len(self.train_dataset) + self.training_args.num_train_adv_examples)
|
||||
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
|
||||
) * (self.training_args.num_epochs - self.training_args.num_clean_epochs)
|
||||
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer = self._get_tensorboard_writer()
|
||||
if self.training_args.log_to_wandb:
|
||||
self._init_wandb()
|
||||
total_training_steps = total_clean_training_steps + total_adv_training_steps
|
||||
self._print_training_args(total_training_steps, train_batch_size)
|
||||
|
||||
optimizer, scheduler = self._get_optimizer_and_scheduler(
|
||||
optimizer, scheduler = self.get_optimizer_and_scheduler(
|
||||
model, total_training_steps
|
||||
)
|
||||
|
||||
collate_func = functools.partial(collate_fn, self.train_dataset.input_columns)
|
||||
|
||||
if self.task_type == "regression":
|
||||
loss_fct = torch.nn.MSELoss()
|
||||
else:
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
model.to(textattack.shared.utils.device)
|
||||
|
||||
# Variables across epochs
|
||||
global_step = 0
|
||||
total_loss = 0.0
|
||||
self._total_loss = 0.0
|
||||
self._current_loss = 0.0
|
||||
self._last_log_step = 0
|
||||
|
||||
# `best_score` is used to keep track of the best model across training.
|
||||
# Could be loss, accuracy, or other metrics.
|
||||
best_eval_score = 0.0
|
||||
best_eval_score_epoch = 0
|
||||
best_model_path = None
|
||||
epochs_since_best_eval_score = 0
|
||||
self._print_training_args(total_training_steps)
|
||||
|
||||
for epoch in range(1, self.training_args.num_epochs + 1):
|
||||
logger.info("==========================================================")
|
||||
logger.info(f"Epoch {epoch}")
|
||||
@@ -344,14 +487,10 @@ class Trainer:
|
||||
# after the clean epochs
|
||||
# adv_example_dataset is instance of `textattack.datasets.Dataset
|
||||
model.eval()
|
||||
model.cpu()
|
||||
adv_example_dataset = self._generate_adversarial_examples(
|
||||
self.train_dataset, epoch
|
||||
)
|
||||
adv_example_dataset = self._generate_adversarial_examples(epoch)
|
||||
train_dataset = torch.utils.data.ConcatDataset(
|
||||
[self.train_dataset, adv_example_dataset]
|
||||
)
|
||||
model.to(textattack.shared.utils.device)
|
||||
model.train()
|
||||
else:
|
||||
train_dataset = self.train_dataset
|
||||
@@ -361,116 +500,123 @@ class Trainer:
|
||||
)
|
||||
train_dataset = self.train_dataset
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_func,
|
||||
train_dataloader = self.get_train_dataloader(
|
||||
train_dataset, train_batch_size
|
||||
)
|
||||
model.train()
|
||||
# Epoch-specific variables
|
||||
correct_predictions = 0
|
||||
total_predictions = 0
|
||||
# Epoch variables
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
prog_bar = tqdm.tqdm(
|
||||
train_dataloader, desc="Iteration", position=0, leave=True
|
||||
)
|
||||
for step, batch in enumerate(prog_bar):
|
||||
input_texts, labels = batch
|
||||
labels = labels.to(textattack.shared.utils.device)
|
||||
if isinstance(model, transformers.PreTrainedModel):
|
||||
input_ids = tokenizer(
|
||||
input_texts,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
)
|
||||
for key in input_ids:
|
||||
if isinstance(input_ids[key], torch.Tensor):
|
||||
input_ids[key] = input_ids[key].to(
|
||||
textattack.shared.utils.device
|
||||
)
|
||||
logits = model(**input_ids)[0]
|
||||
else:
|
||||
input_ids = tokenizer(input_texts)
|
||||
if not isinstance(input_ids, torch.Tensor):
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.to(textattack.shared.utils.device)
|
||||
logits = model(input_ids)
|
||||
|
||||
if self.task_type == "regression":
|
||||
# TODO integrate with textattack `metrics` package
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
pred_labels = logits.argmax(dim=-1)
|
||||
correct_predictions += (pred_labels == labels).sum().item()
|
||||
total_predictions += len(pred_labels)
|
||||
loss, preds, targets = self.training_step(model, tokenizer, batch)
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
loss = loss.mean()
|
||||
if self.training_args.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.training_args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
total_loss += loss.item()
|
||||
loss = loss / self.training_args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
self._total_loss += loss
|
||||
self._current_loss += loss
|
||||
|
||||
all_preds.append(preds)
|
||||
all_targets.append(targets)
|
||||
|
||||
if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
if scheduler:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
self._global_step += 1
|
||||
|
||||
if self._global_step > 0:
|
||||
prog_bar.set_description(
|
||||
f"Loss {self._total_loss/self._global_step:.5f}"
|
||||
)
|
||||
|
||||
# TODO: Better way to handle TB and Wandb logging
|
||||
if global_step % self.training_args.logging_interval_step == 0:
|
||||
if (self._global_step > 0) and (
|
||||
self._global_step % self.training_args.logging_interval_step == 0
|
||||
):
|
||||
lr_to_log = (
|
||||
scheduler.get_last_lr()[0]
|
||||
if scheduler
|
||||
else self.training_args.learning_rate
|
||||
)
|
||||
if self._global_step - self._last_log_step >= 1:
|
||||
loss_to_log = round(
|
||||
self._current_loss
|
||||
/ (self._global_step - self._last_log_step),
|
||||
4,
|
||||
)
|
||||
else:
|
||||
loss_to_log = round(self._current_loss, 4)
|
||||
|
||||
log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer.add_scalar("loss", loss.item(), global_step)
|
||||
tb_writer.add_scalar("lr", lr_to_log, global_step)
|
||||
self._tb_log(log, self._global_step)
|
||||
|
||||
if self.training_args.log_to_wandb:
|
||||
wandb.log({"loss": loss.item()}, step=global_step)
|
||||
wandb.log({"lr": lr_to_log}, step=global_step)
|
||||
self._wandb_log(log, self._global_step)
|
||||
|
||||
if global_step > 0:
|
||||
prog_bar.set_description(f"Loss {total_loss/global_step:.5f}")
|
||||
self._current_loss = 0.0
|
||||
self._last_log_step = self._global_step
|
||||
|
||||
# Save model checkpoint to file.
|
||||
if self.training_args.checkpoint_interval_steps:
|
||||
if (
|
||||
global_step > 0
|
||||
and self.training_args.checkpoint_interval_steps > 0
|
||||
and (global_step % self.training_args.checkpoint_interval_steps)
|
||||
self._global_step > 0
|
||||
and (
|
||||
self._global_step
|
||||
% self.training_args.checkpoint_interval_steps
|
||||
)
|
||||
== 0
|
||||
):
|
||||
self._save_model_checkpoint(model, tokenizer, step=global_step)
|
||||
self._save_model_checkpoint(
|
||||
model, tokenizer, step=self._global_step
|
||||
)
|
||||
|
||||
# Print training accuracy, if we're tracking it.
|
||||
if total_predictions > 0:
|
||||
train_acc = correct_predictions / total_predictions
|
||||
logger.info(f"Train accuracy: {train_acc*100:.2f}%")
|
||||
preds = torch.cat(all_preds)
|
||||
targets = torch.cat(all_targets)
|
||||
if self._metric_name == "accuracy":
|
||||
correct_predictions = (preds == targets).sum().item()
|
||||
accuracy = correct_predictions / len(targets)
|
||||
metric_log = {"train/train_accuracy": accuracy}
|
||||
logger.info(f"Train accuracy: {accuracy*100:.2f}%")
|
||||
else:
|
||||
pearson_correlation, pearson_pvalue = scipy.stats.pearsonr(
|
||||
preds, targets
|
||||
)
|
||||
metric_log = {
|
||||
"train/pearson_correlation": pearson_correlation,
|
||||
"train/pearson_pvalue": pearson_pvalue,
|
||||
}
|
||||
logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}%")
|
||||
|
||||
if len(targets) > 0:
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer.add_scalar("epoch_train_acc", train_acc, global_step)
|
||||
self._tb_log(metric_log, epoch)
|
||||
if self.training_args.log_to_wandb:
|
||||
wandb.log({"epoch_train_acc": train_acc}, step=global_step)
|
||||
metric_log["epoch"] = epoch
|
||||
self._wandb_log(metric_log, self._global_step)
|
||||
|
||||
# Evaluate after each epoch.
|
||||
eval_score = self.evaluate()
|
||||
|
||||
# Check eval accuracy after each epoch.
|
||||
eval_score = self._evaluate(model, tokenizer)
|
||||
logger.info(
|
||||
f"Eval {'pearson correlation' if self.task_type == 'regression' else 'accuracy'}: {eval_score*100:.2f}%"
|
||||
)
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
|
||||
self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
|
||||
if self.training_args.log_to_wandb:
|
||||
wandb.log({"epoch_eval_score": eval_score}, step=global_step)
|
||||
self._wandb_log(
|
||||
{f"eval/{self._metric_name}": eval_score, "epoch": epoch},
|
||||
self._global_step,
|
||||
)
|
||||
|
||||
if (
|
||||
self.training_args.checkpoint_interval_epochs
|
||||
and ((epoch - 1) % self.training_args.checkpoint_interval_epochs) == 0
|
||||
and (epoch % self.training_args.checkpoint_interval_epochs) == 0
|
||||
):
|
||||
self._save_model_checkpoint(model, tokenizer, epoch=epoch)
|
||||
|
||||
@@ -480,7 +626,7 @@ class Trainer:
|
||||
epochs_since_best_eval_score = 0
|
||||
self._save_model_checkpoint(model, tokenizer, best=True)
|
||||
logger.info(
|
||||
f"Best acc found. Saved model to {self.training_args.output_dir}/best_model/"
|
||||
f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
|
||||
)
|
||||
else:
|
||||
epochs_since_best_eval_score += 1
|
||||
@@ -493,64 +639,7 @@ class Trainer:
|
||||
)
|
||||
break
|
||||
|
||||
if self.training_args.eval_adversarial_robustness and (
|
||||
epoch >= self.training_args.num_clean_epochs
|
||||
):
|
||||
# Evaluate adversarial robustness
|
||||
model.eval()
|
||||
model.cpu()
|
||||
adv_attack_results = self._generate_adversarial_examples(
|
||||
self.eval_dataset, epoch, eval_mode=True
|
||||
)
|
||||
model.to(textattack.shared.utils.device)
|
||||
model.train()
|
||||
attack_types = [r.__class__.__name__ for r in adv_attack_results]
|
||||
attack_types = collections.Counter(attack_types)
|
||||
total_attacks = (
|
||||
attack_types["SuccessfulAttackResult"]
|
||||
+ attack_types["FailedAttackResult"]
|
||||
)
|
||||
adv_succ_rate = attack_types["SuccessfulAttackResult"] / total_attacks
|
||||
num_queries = np.array(
|
||||
[
|
||||
r.num_queries
|
||||
for r in adv_attack_results
|
||||
if not isinstance(
|
||||
r, textattack.attack_results.SkippedAttackResult
|
||||
)
|
||||
]
|
||||
)
|
||||
avg_num_queries = round(num_queries.mean(), 2)
|
||||
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer.add_scalar(
|
||||
"robustness_total_attacks", total_attacks, global_step
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"robustness_attack_succ_rate", adv_succ_rate, global_step
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"robustness_avg_num_queries", avg_num_queries, global_step
|
||||
)
|
||||
if self.training_args.log_to_wandb:
|
||||
wandb.log(
|
||||
{"robustness_total_attacks": total_attacks}, step=global_step
|
||||
)
|
||||
wandb.log(
|
||||
{"robustness_attack_succ_rate": adv_succ_rate}, step=global_step
|
||||
)
|
||||
wandb.log(
|
||||
{"robustness_avg_num_queries": avg_num_queries},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
logger.info(f"Eval total attack: {total_attacks}")
|
||||
logger.info(f"Eval attack success rate: {100*adv_succ_rate:.2f}%")
|
||||
logger.info(f"Eval avg num queries: {avg_num_queries}")
|
||||
|
||||
if self.training_args.log_to_tb:
|
||||
tb_writer.flush()
|
||||
|
||||
# Finish training
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
model = model.module
|
||||
|
||||
@@ -560,20 +649,23 @@ class Trainer:
|
||||
model = model.__class__.from_pretrained(best_model_path)
|
||||
else:
|
||||
model = model.load_state_dict(
|
||||
torch.load(os.path.join(best_model_path, "model.pt"))
|
||||
torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
|
||||
)
|
||||
|
||||
if self.training_args.save_last:
|
||||
self._save_model_checkpoint(model, tokenizer, last=True)
|
||||
|
||||
self.model_wrapper.model = model
|
||||
self.write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
|
||||
self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
|
||||
|
||||
def evaluate(self):
|
||||
logging.info("Evaluating model on evaluation dataset.")
|
||||
model = self.model_wrapper.model
|
||||
tokenizer = self.model_wrapper.tokenizer
|
||||
|
||||
def _evaluate(self, model, tokenizer):
|
||||
model.eval()
|
||||
correct = 0
|
||||
logits = []
|
||||
labels = []
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
num_gpus = torch.cuda.device_count()
|
||||
@@ -581,58 +673,33 @@ class Trainer:
|
||||
else:
|
||||
eval_batch_size = self.training_args.per_device_eval_batch_size
|
||||
|
||||
collate_func = functools.partial(collate_fn, self.eval_dataset.input_columns)
|
||||
eval_dataloader = torch.utils.data.DataLoader(
|
||||
self.eval_dataset,
|
||||
batch_size=eval_batch_size,
|
||||
collate_fn=collate_func,
|
||||
)
|
||||
eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)
|
||||
|
||||
with torch.no_grad():
|
||||
for input_texts, batch_labels in eval_dataloader:
|
||||
batch_labels = batch_labels.to(textattack.shared.utils.device)
|
||||
if isinstance(model, transformers.PreTrainedModel):
|
||||
input_ids = tokenizer(
|
||||
input_texts,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
)
|
||||
for key in input_ids:
|
||||
if isinstance(input_ids[key], torch.Tensor):
|
||||
input_ids[key] = input_ids[key].to(
|
||||
textattack.shared.utils.device
|
||||
)
|
||||
batch_logits = model(**input_ids)[0]
|
||||
else:
|
||||
input_ids = tokenizer(input_texts)
|
||||
if not isinstance(input_ids, torch.Tensor):
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.to(textattack.shared.utils.device)
|
||||
batch_logits = model(input_ids)
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
preds, targets = self.evaluate_step(model, tokenizer, batch)
|
||||
all_preds.append(preds)
|
||||
all_targets.append(targets)
|
||||
|
||||
logits.extend(batch_logits.cpu().squeeze().tolist())
|
||||
labels.extend(batch_labels)
|
||||
|
||||
model.train()
|
||||
logits = torch.tensor(logits)
|
||||
labels = torch.tensor(labels)
|
||||
preds = torch.cat(all_preds)
|
||||
targets = torch.cat(all_targets)
|
||||
|
||||
if self.task_type == "regression":
|
||||
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
|
||||
return pearson_correlation
|
||||
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(preds, targets)
|
||||
eval_score = pearson_correlation
|
||||
else:
|
||||
preds = logits.argmax(dim=1)
|
||||
correct = (preds == labels).sum()
|
||||
return float(correct) / len(labels)
|
||||
correct_predictions = (preds == targets).sum().item()
|
||||
accuracy = correct_predictions / len(targets)
|
||||
eval_score = accuracy
|
||||
|
||||
def evaluate(self):
|
||||
logging.info("Evaluating model on evaluation dataset.")
|
||||
model = self.model_wrapper.model
|
||||
tokenizer = self.model_wrapper.tokenizer
|
||||
return self._evaluate(model, tokenizer)
|
||||
if self._metric_name == "accuracy":
|
||||
logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
|
||||
else:
|
||||
logger.info(f"Eval {self._metric_name}: {eval_score:.4f}%")
|
||||
|
||||
def write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
|
||||
return eval_score
|
||||
|
||||
def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
|
||||
if isinstance(self.training_args, CommandLineTrainingArgs):
|
||||
model_name = self.training_args.model_name_or_path
|
||||
elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
|
||||
@@ -658,7 +725,7 @@ class Trainer:
|
||||
):
|
||||
model_max_length = self.training_args.model_max_length
|
||||
elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
|
||||
model_max_length = self.model_wrapper.config.max_position_embeddings
|
||||
model_max_length = self.model_wrapper.model.config.max_position_embeddings
|
||||
elif isinstance(
|
||||
self.model_wrapper.model, (LSTMForClassification, WordCNNForClassification)
|
||||
):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
import datetime
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from textattack.datasets import HuggingFaceDataset
|
||||
from textattack.models.helpers import LSTMForClassification, WordCNNForClassification
|
||||
@@ -15,8 +16,6 @@ from textattack.shared.utils import ARGS_SPLIT_TOKEN
|
||||
from .attack import Attack
|
||||
from .attack_args import ATTACK_RECIPE_NAMES
|
||||
|
||||
# TODO Add `metric_for_best_model` argument. Currently we just use accuracy for classification and MSE for regression by default.
|
||||
|
||||
|
||||
def default_output_dir():
|
||||
return os.path.join(
|
||||
@@ -26,47 +25,72 @@ def default_output_dir():
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
"""Args for TextAttack ``Trainer`` class that is used for running
|
||||
adversarial training.
|
||||
"""Arguments for ``Trainer`` class that is used for adversarial training.
|
||||
|
||||
Args:
|
||||
num_epochs (int): Total number of epochs for training. Default is 5.
|
||||
num_clean_epochs (int): Number of epochs to train on just the original training dataset before adversarial training. Default is 0.
|
||||
attack_epoch_interval (int): Generate a new adversarial training set every N epochs. Default is 1.
|
||||
early_stopping_epochs (int): Number of epochs validation must increase before stopping early (-1 for no early stopping). Default is `None`.
|
||||
learning_rate (float): Learning rate for Adam Optimization. Default is 2e-5.
|
||||
num_warmup_steps (int): The number of steps for the warmup phase of linear scheduler. Default is 500.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default is 0.01.
|
||||
per_device_train_batch_size (int): The batch size per GPU/CPU for training. Default is 8.
|
||||
per_device_eval_batch_size (int): The batch size per GPU/CPU for evaluation. Default is 32.
|
||||
gradient_accumulation_steps (int): Number of updates steps to accumulate the gradients for, before performing a backward/update pass. Default is 1.
|
||||
random_seed (int): Random seed. Default is 786.
|
||||
parallel (bool): If `True`, train using multiple GPUs. Default is `False`.
|
||||
load_best_model_at_end (bool): If `True`, keep track of the best model across training and load it at the end.
|
||||
eval_adversarial_robustness (bool): If set, evaluate adversarial robustness on evaluation dataset after every epoch.
|
||||
num_eval_adv_examples (int): The number of samples attack if `eval_adversarial_robustness=True`. Default is 1000.
|
||||
num_train_adv_examples (int): The number of samples to attack when generating adversarial training set. Default is -1 (which is all possible samples).
|
||||
query_budget_train (:obj:`int`, optional): The max query budget to use when generating adversarial training set.
|
||||
query_budget_eval (:obj:`int`, optional): The max query budget to use when evaluating adversarial robustness.
|
||||
attack_num_workers_per_device (int): Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.
|
||||
output_dir (str): Directory to output training logs and checkpoints.
|
||||
checkpoint_interval_steps (int): Save model checkpoint after every N updates to the model.
|
||||
checkpoint_interval_epochs (int): Save model checkpoint after every N epochs
|
||||
save_last (bool): If `True`, save the model at end of training. Can be used with `load_best_model_at_end` to save the best model at the end. Default is `True`.
|
||||
log_to_tb (bool): If `True`, log to Tensorboard. Default is `False`
|
||||
tb_log_dir (str): Path of Tensorboard log directory.
|
||||
log_to_wandb (bool): If `True`, log to Wandb. Default is `False`.
|
||||
wandb_project (str): Name of Wandb project for logging. Default is `textattack`.
|
||||
logging_interval_step (int): Log to Tensorboard/Wandb every N steps.
|
||||
num_epochs (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
Total number of epochs for training.
|
||||
num_clean_epochs (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Number of epochs to train on just the original training dataset before adversarial training.
|
||||
attack_epoch_interval (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Generate a new adversarial training set every N epochs.
|
||||
early_stopping_epochs (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
Number of consecutive epochs without an improvement in the evaluation score after which training stops early (`None` for no early stopping).
|
||||
learning_rate (:obj:`float`, `optional`, defaults to :obj:`5e-5`):
|
||||
Learning rate for optimizer.
|
||||
num_warmup_steps (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`500`):
|
||||
The number of steps for the warmup phase of linear scheduler.
|
||||
If `num_warmup_steps` is a `float` between 0 and 1, the number of warmup steps will be `math.ceil(num_training_steps * num_warmup_steps)`.
|
||||
weight_decay (:obj:`float`, `optional`, defaults to :obj:`0.01`):
|
||||
Weight decay (L2 penalty).
|
||||
per_device_train_batch_size (:obj:`int`, `optional`, defaults to :obj:`8`):
|
||||
The batch size per GPU/CPU for training.
|
||||
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
|
||||
The batch size per GPU/CPU for evaluation.
|
||||
gradient_accumulation_steps (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Number of update steps over which to accumulate gradients before performing a backward/update pass.
|
||||
random_seed (:obj:`int`, `optional`, defaults to :obj:`786`):
|
||||
Random seed for reproducibility.
|
||||
parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, train on multiple GPUs using `torch.nn.DataParallel`.
|
||||
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, keep track of the best model across training and load it at the end.
|
||||
alpha (:obj:`float`, `optional`, defaults to :obj:`1.0`):
|
||||
The weight for adversarial loss.
|
||||
num_train_adv_examples (:obj:`int` or :obj:`float`, `optional`, defaults to :obj:`-1`):
|
||||
The number of samples to successfully attack when generating adversarial training set before start of every epoch.
|
||||
If `num_train_adv_examples` is a `float` between 0 and 1, the number of adversarial examples generated is
|
||||
that fraction of the original training set.
|
||||
query_budget_train (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
The max query budget to use when generating adversarial training set. `None` means infinite query budget.
|
||||
attack_num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.
|
||||
output_dir (:obj:`str`, `optional`):
|
||||
Directory to output training logs and checkpoints. Defaults to `./outputs/%Y-%m-%d-%H-%M-%S-%f` format.
|
||||
checkpoint_interval_steps (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
If set, save model checkpoint after every N updates to the model.
|
||||
checkpoint_interval_epochs (:obj:`int`, `optional`, defaults to :obj:`None`):
|
||||
If set, save model checkpoint after every N epochs.
|
||||
save_last (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `True`, save the model at end of training. Can be used with `load_best_model_at_end` to save the best model at the end.
|
||||
log_to_tb (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, log to Tensorboard.
|
||||
tb_log_dir (:obj:`str`, `optional`, defaults to :obj:`"./runs"`):
|
||||
Path of Tensorboard log directory.
|
||||
log_to_wandb (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If `True`, log to Wandb.
|
||||
wandb_project (:obj:`str`, `optional`, defaults to :obj:`"textattack"`):
|
||||
Name of Wandb project for logging.
|
||||
logging_interval_step (:obj:`int`, `optional`, defaults to :obj:`1`):
|
||||
Log to Tensorboard/Wandb every N training steps.
|
||||
"""
|
||||
|
||||
num_epochs: int = 5
|
||||
num_clean_epochs: int = 0
|
||||
num_epochs: int = 3
|
||||
num_clean_epochs: int = 1
|
||||
attack_epoch_interval: int = 1
|
||||
early_stopping_epochs: int = None
|
||||
learning_rate: float = 5e-5
|
||||
lr: float = None # alternative keyword arg for learning_rate
|
||||
num_warmup_steps: int = 500
|
||||
num_warmup_steps: Union[int, float] = 500
|
||||
weight_decay: float = 0.01
|
||||
per_device_train_batch_size: int = 8
|
||||
per_device_eval_batch_size: int = 32
|
||||
@@ -74,11 +98,9 @@ class TrainingArgs:
|
||||
random_seed: int = 786
|
||||
parallel: bool = False
|
||||
load_best_model_at_end: bool = False
|
||||
eval_adversarial_robustness: bool = False
|
||||
num_eval_adv_examples: int = 1000
|
||||
num_train_adv_examples: int = -1
|
||||
alpha: float = 1.0
|
||||
num_train_adv_examples: Union[int, float] = -1
|
||||
query_budget_train: int = None
|
||||
query_budget_eval: int = None
|
||||
attack_num_workers_per_device: int = 1
|
||||
output_dir: str = field(default_factory=default_output_dir)
|
||||
checkpoint_interval_steps: int = None
|
||||
@@ -91,8 +113,6 @@ class TrainingArgs:
|
||||
logging_interval_step: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lr:
|
||||
self.learning_rate = self.lr
|
||||
assert self.num_epochs > 0, "`num_epochs` must be greater than 0."
|
||||
assert (
|
||||
self.num_clean_epochs >= 0
|
||||
@@ -101,16 +121,13 @@ class TrainingArgs:
|
||||
assert (
|
||||
self.early_stopping_epochs > 0
|
||||
), "`early_stopping_epochs` must be greater than 0."
|
||||
if self.attack_epoch_interval is not None:
|
||||
assert (
|
||||
self.attack_epoch_interval > 0
|
||||
), "`attack_epoch_interval` must be greater than 0."
|
||||
assert (
|
||||
self.attack_epoch_interval > 0
|
||||
), "`attack_epoch_interval` must be greater than 0."
|
||||
assert self.num_warmup_steps > 0, "`num_warmup_steps` must be greater than 0."
|
||||
assert (
|
||||
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
|
||||
), "`num_train_adv_examples` must be greater than 0 or equal to -1."
|
||||
assert (
|
||||
self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1
|
||||
), "`num_eval_adv_examples` must be greater than 0 or equal to -1."
|
||||
self.num_warmup_steps >= 0
|
||||
), "`num_warmup_steps` must be greater than or equal to 0."
|
||||
assert (
|
||||
self.gradient_accumulation_steps > 0
|
||||
), "`gradient_accumulation_steps` must be greater than 0."
|
||||
@@ -118,117 +135,130 @@ class TrainingArgs:
|
||||
self.num_clean_epochs <= self.num_epochs
|
||||
), f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
|
||||
|
||||
if isinstance(self.num_train_adv_examples, float):
|
||||
assert (
|
||||
self.num_train_adv_examples >= 0.0
|
||||
and self.num_train_adv_examples <= 1.0
|
||||
), "If `num_train_adv_examples` is float, it must be between 0 and 1."
|
||||
elif isinstance(self.num_train_adv_examples, int):
|
||||
assert (
|
||||
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
|
||||
), "If `num_train_adv_examples` is int, it must be greater than 0 or equal to -1."
|
||||
else:
|
||||
raise TypeError("`num_train_adv_examples` must be of either type `int` or `float`.")
|
||||
|
||||
@classmethod
|
||||
def add_parser_args(cls, parser):
|
||||
"""Add listed args to command line parser."""
|
||||
default_obj = cls()
|
||||
|
||||
def int_or_float(v):
    """Convert a command-line value to ``int`` if possible, otherwise ``float``.

    Used as the argparse ``type`` for options that accept either an absolute
    count or a fractional value (``--num-warmup-steps`` and
    ``--num-train-adv-examples``).

    Args:
        v: The raw value to convert.

    Returns:
        ``int(v)`` when that conversion succeeds, else ``float(v)``.

    Raises:
        ValueError: If ``v`` can be parsed neither as an int nor as a float.
    """
    try:
        parsed = int(v)
    except ValueError:
        # Not an integer literal (e.g. "0.1" or "1e-3"); fall back to float.
        parsed = float(v)
    return parsed
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=4,
|
||||
default=default_obj.num_epochs,
|
||||
help="Total number of epochs for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-clean-epochs",
|
||||
type=int,
|
||||
default=0,
|
||||
default=default_obj.num_clean_epochs,
|
||||
help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attack-epoch-interval",
|
||||
type=int,
|
||||
default=1,
|
||||
default=default_obj.attack_epoch_interval,
|
||||
help="Generate a new adversarial training set every N epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early-stopping-epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
default=default_obj.early_stopping_epochs,
|
||||
help="Number of epochs validation must increase before stopping early (-1 for no early stopping)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate",
|
||||
"--lr",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
default=default_obj.learning_rate,
|
||||
help="Learning rate for Adam Optimization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-warmup-steps",
|
||||
type=float,
|
||||
default=500,
|
||||
type=int_or_float,
|
||||
default=default_obj.num_warmup_steps,
|
||||
help="The number of steps for the warmup phase of linear scheduler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-decay",
|
||||
type=float,
|
||||
default=0.01,
|
||||
default=default_obj.weight_decay,
|
||||
help="Weight decay (L2 penalty).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-device-train-batch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
default=default_obj.per_device_train_batch_size,
|
||||
help="The batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-device-eval-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
default=default_obj.per_device_eval_batch_size,
|
||||
help="The batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient-accumulation-steps",
|
||||
type=int,
|
||||
default=1,
|
||||
default=default_obj.gradient_accumulation_steps,
|
||||
help="Number of updates steps to accumulate the gradients for, before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--random-seed", type=int, default=786, help="Random seed.")
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
type=int,
|
||||
default=default_obj.random_seed,
|
||||
help="Random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--parallel",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.parallel,
|
||||
help="If set, run training on multiple GPUs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-best-model-at-end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.load_best_model_at_end,
|
||||
help="If set, keep track of the best model across training and load it at the end.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-adversarial-robustness",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, evaluate adversarial robustness on evaluation dataset after every epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-eval-adv-examples",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The number of samples attack if `eval_adversarial_robustness=True`. Default is 1000.",
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The weight of adversarial loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-train-adv-examples",
|
||||
type=int,
|
||||
default=-1,
|
||||
type=int_or_float,
|
||||
default=default_obj.num_train_adv_examples,
|
||||
help="The number of samples to attack when generating adversarial training set. Default is -1 (which is all possible samples).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-budget-train",
|
||||
type=int,
|
||||
default=None,
|
||||
default=default_obj.query_budget_train,
|
||||
help="The max query budget to use when generating adversarial training set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-budget-eval",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The max query budget to use when evaluating adversarial robustness.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attack-num-workers-per-device",
|
||||
type=int,
|
||||
default=1,
|
||||
default=default_obj.attack_num_workers_per_device,
|
||||
help="Number of worker processes to run per device for attack. Same as `num_workers_per_device` argument for `AttackArgs`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -240,49 +270,49 @@ class TrainingArgs:
|
||||
parser.add_argument(
|
||||
"--checkpoint-interval-steps",
|
||||
type=int,
|
||||
default=None,
|
||||
default=default_obj.checkpoint_interval_steps,
|
||||
help="Save model checkpoint after every N updates to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-interval-epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
default=default_obj.checkpoint_interval_epochs,
|
||||
help="Save model checkpoint after every N epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-last",
|
||||
action="store_true",
|
||||
default=True,
|
||||
default=default_obj.save_last,
|
||||
help="If set, save the model at end of training. Can be used with `--load-best-model-at-end` to save the best model at the end.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-to-tb",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.log_to_tb,
|
||||
help="If set, log to Tensorboard",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tb-log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
default=default_obj.tb_log_dir,
|
||||
help="Path of Tensorboard log directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-to-wandb",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=default_obj.log_to_wandb,
|
||||
help="If set, log to Wandb.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb-project",
|
||||
type=str,
|
||||
default="textattack",
|
||||
default=default_obj.wandb_project,
|
||||
help="Name of Wandb project for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging-interval-step",
|
||||
type=int,
|
||||
default=1,
|
||||
default=default_obj.logging_interval_step,
|
||||
help="Log to Tensorboard/Wandb every N steps.",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user