mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

WIP: add datasets and attacker

This commit is contained in:
Jin Yong Yoo
2020-12-27 08:02:38 -05:00
parent 10f54a16f4
commit 2744fd3ca8
10 changed files with 363 additions and 137 deletions

View File

@@ -319,15 +319,15 @@ TEXTATTACK_DATASET_BY_MODEL = {
#
"t5-en-de": (
"english_to_german",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
("textattack.datasets.helpers.TedMultiTranslationDataset", "en", "de"),
),
"t5-en-fr": (
"english_to_french",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "fr"),
("textattack.datasets.helpers.TedMultiTranslationDataset", "en", "fr"),
),
"t5-en-ro": (
"english_to_romanian",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
("textattack.datasets.helpers.TedMultiTranslationDataset", "en", "de"),
),
#
# T5 for summarization

View File

@@ -357,7 +357,7 @@ def parse_dataset_from_args(args):
_, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model]
if dataset[0].startswith("textattack"):
# unsavory way to pass custom dataset classes
# ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de')
# ex: dataset = ('textattack.datasets.helpers.TedMultiTranslationDataset', 'en', 'de')
dataset = eval(f"{dataset[0]}")(*dataset[1:])
return dataset
else:

View File

@@ -21,8 +21,11 @@ class Dataset:
data (list_like): A list-like iterable of ``(input, output)`` pairs. Here, `output` can either be an integer representing labels for classification
or a string for seq2seq tasks. If input consists of multiple sequences (e.g. SNLI), iterable
should be of the form ``([input_1, input_2, ...], output)`` and ``input_columns`` parameter must be set.
lang (str): Two letter ISO 639-1 code representing the language of the input data (e.g. "en", "fr", "ko", "zh"). Default is "en".
input_columns (list[str]): List of column names of inputs in order. Default is ``["text"]`` for single text input.
lang (str, optional): Two letter ISO 639-1 code representing the language of the input data (e.g. "en", "fr", "ko", "zh"). Default is "en".
input_columns (list[str], optional): List of column names of inputs in order. Default is ``["text"]`` for single text input.
label_names (list[str], optional): List of label names in corresponding order (e.g. ``["World", "Sports", "Business", "Sci/Tech"]`` for the AG-News dataset).
If not set, labels will be printed as-is (e.g. "0", "1", ...). This should be set to ``None`` for non-classification datasets.
shuffle (bool): Whether to shuffle the dataset on load.
Examples::
@@ -30,23 +33,24 @@ class Dataset:
>>> # Example of sentiment-classification dataset
>>> data = [("I enjoyed the movie a lot!", 1), ("Absolutely horrible film.", 0), ("Our family had a fun time!", 1)]
>>> dataset = textattack.dataset.Dataset(data, lang="en")
>>> dataset = textattack.datasets.Dataset(data, lang="en")
>>> dataset[1:2]
>>> # Example for pair of sequence inputs (e.g. SNLI)
>>> data = [("A man inspects the uniform of a figure in some East Asian country.", "The man is sleeping"), 1)]
>>> dataset = textattack.dataset.Dataset(data, lang="en", input_columns=("premise", "hypothesis"))
>>> dataset = textattack.datasets.Dataset(data, lang="en", input_columns=("premise", "hypothesis"))
>>> # Example for seq2seq
>>> data = [("J'aime le film.", "I love the movie.")]
>>> dataset = textattack.dataset.Dataset(data, lang="fr")
>>> dataset = textattack.datasets.Dataset(data, lang="fr")
"""
def __init__(self, data, lang="en", input_columns=["text"], shuffle=False):
self.data = data
def __init__(self, data, lang="en", input_columns=["text"], label_names=None, shuffle=False):
self._data = data
self.lang = lang
self.input_columns = input_columns
self.label_names = label_names
self.shuffled = shuffle
if shuffle:
@@ -72,27 +76,27 @@ class Dataset:
def __getitem__(self, i):
if isinstance(i, int):
return self._format_example(self.data[i])
return self._format_as_dict(self._data[i])
else:
# `i` could be a slice or an integer. if it's a slice,
# return the formatted version of the proper slice of the list
return [self._format_example(ex) for ex in self.data[i]]
return [self._format_as_dict(ex) for ex in self._data[i]]
def __len__(self):
return len(self._data)
class IterableDataset(ABC):
class IterableDataset(Dataset, ABC):
"""Basic class for datasets that fetch data via ``__iter__`` protocol. Idea
is similar to PyTorch's ``IterableDataset``. This is useful if you cannot
load the entire dataset to memory, such as reading from a large file.
Unlike ``Dataset``, ``IterableDataset`` is an abstract base class, meaning
that you need to extend it with your own custom child class and define the
``format_example`` method. This is to support flexible preprocessing of
each example returned by the underlying iterator.
``format_example`` method, which is responsible for formatting each example into an ``(input, output)`` pair.
However, for most cases that involve loading from txt, CSV, or JSON files, we recommend loading the data as a Huggingface ``datasets.Dataset`` object and
passing it to TextAttack's ``HuggingFaceDataset`` class. This class is designed for cases where you really need to load data via the ``__iter__`` protocol.
passing it to TextAttack's ``HuggingFaceDataset`` class. It uses Apache Arrow as its backend and will be sufficient for loading large datasets.
This class is designed for cases where you really need to load data via the ``__iter__`` protocol.
For more information about loading files as ``datasets.Dataset`` object, visit https://huggingface.co/docs/datasets/loading_datasets.html.
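A minimal sketch of a custom child class, assuming ``IterableDataset`` is exported from ``textattack.datasets`` and the data is a large tab-separated file of ``text<TAB>label`` lines (the file name and parsing are hypothetical):

    import textattack

    class TsvStreamDataset(textattack.datasets.IterableDataset):
        # Streams examples one line at a time instead of loading the file into memory.
        def format_example(self, example):
            # `example` is one raw line yielded by the underlying iterator.
            text, label = example.rstrip("\n").split("\t")
            return (text, int(label))

    dataset = TsvStreamDataset(open("data.tsv"), lang="en", input_columns=["text"])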
@@ -100,6 +104,9 @@ class IterableDataset(ABC):
data_iterator: Iterator that returns next element when ``next(data_iterator)`` is called.
lang (str): Two letter ISO 639-1 code representing the language of the data (e.g. "en", "fr", "ko", "zh"). Default is "en".
input_columns (list[str]): List of column names of inputs in order. Default is ``["text"]`` for single text input.
label_names (list[str], optional): List of label names in corresponding order (e.g. ``["World", "Sports", "Business", "Sci/Tech"]`` for the AG-News dataset).
If not set, labels will be printed as-is (e.g. "0", "1", ...). This should be set to ``None`` for non-classification datasets.
Examples::
Suppose `data.csv` looks like the following:
@@ -127,12 +134,11 @@ class IterableDataset(ABC):
(OrderedDict([('text', 'Our family had a fun time!')]), 1)
"""
def __init__(self, data_iterator, lang="en", input_columns=["text"]):
def __init__(self, data_iterator, lang="en", input_columns=["text"], label_names=None):
self.data_iterator = data_iterator
self.lang = lang
self.input_columns = input_columns
self._i = 0
self.label_names = label_names
def _format_as_dict(self, example):
output = example[1]

View File

@@ -1,6 +1,6 @@
"""
Multi TranslationDataset
Dataset Helpers
=============================
"""

View File

@@ -34,8 +34,6 @@ class TedMultiTranslationDataset(HuggingFaceDataset):
)
self.source_lang = source_lang
self.target_lang = target_lang
self.label_names = ("Translation",)
self._i = 0
def _format_raw_example(self, raw_example):
translations = np.array(raw_example["translation"])

View File

@@ -9,6 +9,7 @@ import random
import datasets
import textattack
from .dataset import Dataset
# from textattack.shared import AttackedText
@@ -19,6 +20,7 @@ def _cb(s):
def get_datasets_dataset_columns(dataset):
"""Common schemas for datasets found in dataset hub"""
schema = set(dataset.column_names)
if {"premise", "hypothesis", "label"} <= schema:
input_columns = ("premise", "hypothesis")
@@ -52,29 +54,27 @@ def get_datasets_dataset_columns(dataset):
output_column = "label"
else:
raise ValueError(
f"Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead."
f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
)
return input_columns, output_column
class HuggingFaceDataset:
class HuggingFaceDataset(Dataset):
"""Loads a dataset from HuggingFace ``datasets`` and prepares it as a
TextAttack dataset.
- name_or_dataset: the dataset name or actual ``datasets.Dataset`` object. If it's your custom ``datasets.Dataset`` object,
- name_or_dataset (Union[datasets.Dataset, str]): the dataset name or actual ``datasets.Dataset`` object. If it's your custom ``datasets.Dataset`` object,
please pass the input and output columns via ``dataset_columns`` argument.
- subset: the subset of the main dataset. Dataset will be loaded as ``datasets.load_dataset(name, subset)``.
- label_map: Mapping if output labels should be re-mapped. Useful
if model was trained with a different label arrangement than
provided in the ``datasets`` version of the dataset.
- output_scale_factor (float): Factor to divide ground-truth outputs by.
Generally, TextAttack goal functions require model outputs
between 0 and 1. Some datasets test the model's correlation
with ground-truth output, instead of its accuracy, so these
outputs may be scaled arbitrarily.
- dataset_columns (tuple(list[str], str))): Pair of ``list[str]`` representing list of input column names (e.g. ["premise", "hypothesis"]) and
``str`` representing the output column name (e.g. ``label``).
- subset (str, optional): the subset of the main dataset. Dataset will be loaded as ``datasets.load_dataset(name, subset)``. Default is ``None``.
- split (str, optional): the split of the dataset. Default is "train".
- lang (str, optional): Two letter ISO 639-1 code representing the language of the input data (e.g. "en", "fr", "ko", "zh"). Default is "en".
- dataset_columns (tuple(list[str], str), optional): Pair of ``list[str]`` representing list of input column names (e.g. ``["premise", "hypothesis"]``) and
``str`` representing the output column name (e.g. ``label``). If not set, we will try to automatically determine column names from known designs.
- label_map (dict, optional): Mapping if output labels should be re-mapped. Useful if model was trained with a different label arrangement than
provided in the ``datasets`` version of the dataset. For example, to remap "Positive" label to 1 and "Negative" label to 0, pass `{"Positive": 1, "Negative": 0}`.
- label_names (list[str], optional): List of label names in corresponding order (e.g. ``["World", "Sports", "Business", "Sci/Tech"]`` for the AG-News dataset).
If ``datasets.Dataset`` object already has label names, then this is not required. Also, this should be set to ``None`` for non-classification datasets.
- shuffle (bool): Whether to shuffle the dataset on load.
"""
@@ -83,9 +83,10 @@ class HuggingFaceDataset:
name_or_dataset,
subset=None,
split="train",
label_map=None,
output_scale_factor=None,
lang="en",
dataset_columns=None,
label_map=None,
label_names=None,
shuffle=False,
):
if isinstance(name_or_dataset, datasets.Dataset):
@@ -102,28 +103,14 @@ class HuggingFaceDataset:
self.input_columns,
self.output_column,
) = dataset_columns or get_datasets_dataset_columns(self._dataset)
self._i = 0
self.examples = list(self._dataset)
self.label_map = label_map
self.output_scale_factor = output_scale_factor
try:
self.label_names = self._dataset.features["label"].names
# If labels are remapped, the label names have to be remapped as
# well.
if label_map:
self.label_names = [
self.label_names[self.label_map[i]]
for i in range(len(self.label_map))
]
except KeyError:
# This happens when the dataset doesn't have 'features' or a 'label' column.
self.label_names = None
except AttributeError:
# This happens when self._dataset.features["label"] exists
# but is a single value.
self.label_names = ("label",)
if shuffle:
random.shuffle(self.examples)
self._dataset = self._dataset.shuffle()
def _format_as_dict(self, example):
input_dict = collections.OrderedDict(
@@ -133,25 +120,16 @@ class HuggingFaceDataset:
output = example[self.output_column]
if self.label_map:
output = self.label_map[output]
if self.output_scale_factor:
output = output / self.output_scale_factor
return (input_dict, output)
def __next__(self):
if self._i >= len(self.examples):
raise StopIteration
example = self.examples[self._i]
self._i += 1
return self._format_as_dict(example)
def __getitem__(self, i):
if isinstance(i, int):
return self._format_as_dict(self.examples[i])
return self._format_as_dict(self._dataset[i])
else:
# `i` could be a slice or an integer. if it's a slice,
# return the formatted version of the proper slice of the list
return [self._format_as_dict(ex) for ex in self.examples[i]]
return [self._format_as_dict(ex) for ex in self._dataset[i]]
def __len__(self):
return len(self.examples)
return len(self._dataset)

View File

@@ -12,11 +12,11 @@ from .logger import Logger
class WeightsAndBiasesLogger(Logger):
"""Logs attack results to Weights & Biases."""
def __init__(self, filename="", stdout=False):
def __init__(self, project_name):
global wandb
wandb = LazyLoader("wandb", globals(), "wandb")
wandb.init(project="textattack", resume=True)
wandb.init(project=project_name)
self._result_table_rows = []
def __setstate__(self, state):

View File

@@ -29,8 +29,7 @@ from textattack.transformations import CompositeTransformation
class Attack:
"""An attack generates adversarial examples on text.
This is an abstract class that contains main helper functionality for
attacks. An attack is comprised of a search method, goal function,
An attack is comprised of a search method, goal function,
a transformation, and a set of one or more linguistic constraints that
successful examples must meet.
@@ -255,7 +254,7 @@ class Attack:
filtered_texts.sort(key=lambda t: t.text)
return filtered_texts
def attack_one(self, initial_result):
def _attack(self, initial_result):
"""Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
``initial_result``.
@@ -286,66 +285,24 @@ class Attack:
else:
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
def _get_examples_from_dataset(self, dataset, indices=None):
"""Gets examples from a dataset and tokenizes them.
def attack(self, example, ground_truth_output):
"""Attack a single example represented as ``AttackedText``
Args:
dataset: An iterable of (text_input, ground_truth_output) pairs
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
example (AttackedText): example to attack.
ground_truth_output: ground truth output of ``example``.
Returns:
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
AttackResult
"""
if indices is None:
indices = range(len(dataset))
assert isinstance(example, AttackedText), "`example` must be of type `AttackedText`."
assert isinstance(ground_truth_output, (int, str)), "`ground_truth_output` must be either `int` or `str`."
goal_function_result, _ = self.goal_function.init_attack_example(example, ground_truth_output)
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
return SkippedAttackResult(goal_function_result)
else:
result = self._attack(goal_function_result)
return result
if not isinstance(indices, deque):
indices = deque(sorted(indices))
if not indices:
return
yield
while indices:
i = indices.popleft()
try:
text_input, ground_truth_output = dataset[i]
except IndexError:
utils.logger.warn(
f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
)
break
try:
# get label names from dataset, if possible
label_names = dataset.label_names
except AttributeError:
label_names = None
attacked_text = AttackedText(
text_input, attack_attrs={"label_names": label_names}
)
goal_function_result, _ = self.goal_function.init_attack_example(
attacked_text, ground_truth_output
)
yield goal_function_result
def attack_dataset(self, dataset, indices=None):
"""Runs an attack on the given dataset and outputs the results to the
console and the output file.
Args:
dataset: An iterable of (text, ground_truth_output) pairs.
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
"""
examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result in examples:
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
yield SkippedAttackResult(goal_function_result)
else:
result = self.attack_one(goal_function_result)
yield result
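A sketch of the new single-example entry point, assuming ``attack`` is an already-constructed ``Attack`` instance (e.g. built from an attack recipe):

    from textattack.shared import AttackedText

    example = AttackedText("I enjoyed the movie a lot!")
    result = attack.attack(example, ground_truth_output=1)  # returns an AttackResult
    print(result)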
def __repr__(self):
"""Prints attack parameters in a human-readable string.

View File

@@ -0,0 +1,287 @@
import collections
import time
import torch
import tqdm
import textattack
from textattack.shared.utils import logger
from .attack import Attack
class AttackArgs:
def __init__(self):
pass
class CliAttackArgs(AttackArgs):
def __init__(self):
pass
class Attacker:
def __init__(self, attack, attack_log_manager=None):
assert isinstance(attack, Attack), f"`attack` argument must be of type `textattack.shared.Attack`, but received argument of type `{type(attack)}`."
self.attack = attack
if attack_log_manager is None:
self.attack_log_manager = textattack.loggers.AttackLogManager()
else:
self.attack_log_manager = attack_log_manager
def _attack(self, dataset, attack_args=None, checkpoint=None):
"""Internal method that carries out attack. No parallel processing is involved.
Args:
dataset (textattack.datasets.Dataset): Dataset to attack
attack_args (textattack.shared.AttackArgs, optional): Arguments for attack. These will be overridden by the checkpoint's arguments if `checkpoint` is not `None`.
checkpoint (textattack.shared.Checkpoint, optional): Checkpoint from which to resume the attack.
"""
if checkpoint is not None and attack_args is not None:
raise ValueError("`attack_args` and `checkpoint` cannot both be set.")
if checkpoint:
num_remaining_attacks = checkpoint.num_remaining_attacks
worklist = checkpoint.worklist
worklist_tail = checkpoint.worklist_tail
# Restore the arguments saved with the checkpoint so the rest of this method can use `attack_args`.
attack_args = checkpoint.args
logger.info(
"Recovered from checkpoint previously saved at {}.".format(
checkpoint.datetime
)
)
else:
num_remaining_attacks = attack_args.num_examples
worklist = collections.deque(range(0, attack_args.num_examples))
worklist_tail = worklist[-1]
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
if checkpoint:
num_results = checkpoint.results_count
num_failures = checkpoint.num_failed_attacks
num_successes = checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_successes = 0
i = 0
while i < len(worklist):
idx = worklist[i]
i += 1
example, ground_truth_output = dataset[idx]
example = textattack.shared.AttackedText(example)
if dataset.label_names is not None:
example.attack_attrs["label_names"] = dataset.label_names
result = self.attack.attack(example, ground_truth_output)
self.attack_log_manager.log_result(result)
if not attack_args.disable_stdout:
print("\n")
if isinstance(result, textattack.attack_results.SkippedAttackResult) and attack_args.attack_n:
# `worklist_tail` keeps track of highest idx that has been part of worklist.
# This is useful for async-logging that can happen when using parallel processing.
# Used to get the next dataset element when attacking with `attack_n` = True.
worklist_tail += 1
worklist.append(worklist_tail)
else:
pbar.update(1)
num_results += 1
if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1
pbar.set_description(
"[Succeeded / Failed / Total] {} / {} / {}".format(
num_successes, num_failures, num_results
)
)
if (
attack_args.checkpoint_interval
and len(self.attack_log_manager.results) % attack_args.checkpoint_interval == 0
):
new_checkpoint = textattack.shared.Checkpoint(
attack_args, self.attack_log_manager, worklist, worklist_tail
)
new_checkpoint.save()
self.attack_log_manager.flush()
pbar.close()
# Enable summary stdout
if attack_args.disable_stdout:
self.attack_log_manager.enable_stdout()
self.attack_log_manager.log_summary()
self.attack_log_manager.flush()
def _attack_parallel(self, dataset, num_workers_per_device, attack_args=None, checkpoint=None):
if checkpoint is not None and attack_args is not None:
raise ValueError("`attack_args` and `checkpoint` cannot both be set.")
if checkpoint:
num_remaining_attacks = checkpoint.num_remaining_attacks
worklist = checkpoint.worklist
worklist_tail = checkpoint.worklist_tail
# Restore the arguments saved with the checkpoint (mirrors `_attack`).
attack_args = checkpoint.args
logger.info(
"Recovered from checkpoint previously saved at {}.".format(
checkpoint.datetime
)
)
else:
num_remaining_attacks = attack_args.num_examples
worklist = collections.deque(range(0, attack_args.num_examples))
worklist_tail = worklist[-1]
# We reserve the first GPU for coordinating workers.
num_gpus = torch.cuda.device_count()
num_workers = num_workers_per_device * num_gpus
textattack.shared.logger.info(f"Running {num_workers} workers on {num_gpus} GPUs")
in_queue = torch.multiprocessing.Queue()
out_queue = torch.multiprocessing.Queue()
# Add stuff to queue.
missing_datapoints = set()
for i in worklist:
try:
example, ground_truth_output = dataset[i]
example = textattack.shared.AttackedText(example)
if dataset.label_names is not None:
example.attack_attrs["label_names"] = dataset.label_names
in_queue.put((i, example, ground_truth_output))
except IndexError:
missing_datapoints.add(i)
# if our dataset is shorter than the number of samples chosen, remove the
# out-of-bounds indices from the dataset
for i in missing_datapoints:
worklist.remove(i)
# Start workers.
torch.multiprocessing.Pool(num_workers, attack_from_queue, (attack_args, in_queue, out_queue))
# Log results asynchronously and update progress bar.
if checkpoint:
num_results = checkpoint.results_count
num_failures = checkpoint.num_failed_attacks
num_successes = checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_successes = 0
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
while worklist:
result = out_queue.get(block=True)
if isinstance(result, Exception):
raise result
idx, result = result
self.attack_log_manager.log_result(result)
worklist.remove(idx)
if (not attack_args.attack_n) or (
not isinstance(result, textattack.attack_results.SkippedAttackResult)
):
pbar.update()
num_results += 1
if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1
pbar.set_description(
"[Succeeded / Failed / Total] {} / {} / {}".format(
num_successes, num_failures, num_results
)
)
else:
# worklist_tail keeps track of highest idx that has been part of worklist
# Used to get the next dataset element when attacking with `attack_n` = True.
worklist_tail += 1
try:
text, output = dataset[worklist_tail]
worklist.append(worklist_tail)
in_queue.put((worklist_tail, text, output))
except IndexError:
raise IndexError(
"Tried adding to worklist, but ran out of datapoints. Size of data is {} but tried to access index {}".format(
len(dataset), worklist_tail
)
)
if (
attack_args.checkpoint_interval
and len(self.attack_log_manager.results) % attack_args.checkpoint_interval == 0
):
new_checkpoint = textattack.shared.Checkpoint(
attack_args, self.attack_log_manager, worklist, worklist_tail
)
new_checkpoint.save()
self.attack_log_manager.flush()
pbar.close()
print()
# Enable summary stdout.
if attack_args.disable_stdout:
self.attack_log_manager.enable_stdout()
self.attack_log_manager.log_summary()
self.attack_log_manager.flush()
print()
def attack(self, dataset, attack_args):
"""Attack `dataset` and record results to specified loggers.
Args:
dataset (textattack.datasets.Dataset): dataset to attack.
attack_args (textattack.shared.AttackArgs): arguments for attack.
"""
assert isinstance(dataset, textattack.datasets.Dataset), f"`dataset` argument must be of type `textattack.datasets.Dataset`, but received argument of type `{type(dataset)}`."
assert isinstance(attack_args, textattack.shared.AttackArgs), f"`attack_args` argument must be of type `textattack.shared.AttackArgs`, but received argument of type `{type(attack_args)}`."
self._attack(dataset, attack_args=attack_args)
def attack_parallel(self, dataset, attack_args, num_workers_per_device):
"""Attack `dataset` with single worker and record results to specified loggers.
Args:
dataset (textattack.datasets.Dataset): dataset to attack.
num_workers_per_device (int): Number of worker threads to run per device. For example, if you are using GPUs and ``num_workers_per_device=2``,
then 2 processes will be running in each GPU. If you are only using CPU, then this is equivalent to running 2 processes concurrently.
"""
assert isinstance(dataset, textattack.datasets.Dataset), f"`dataset` argument must be of type `textattack.datasets.Dataset`, but received argument of type `{type(dataset)}`."
assert isinstance(attack_args, textattack.shared.AttackArgs), f"`attack_args` argument must be of type `textattack.shared.AttackArgs`, but received argument of type `{type(attack_args)}`."
self._attack_parallel(dataset, num_workers_per_device, attack_args=attack_args)
def resume_attack(self, dataset, checkpoint):
"""Resume attacking `dataset` from saved `checkpoint`.
Args:
dataset (textattack.datasets.Dataset): dataset to attack.
checkpoint (textattack.shared.Checkpoint): checkpoint object that has previously been saved.
"""
assert isinstance(dataset, textattack.datasets.Dataset), f"`dataset` argument must be of type `textattack.datasets.Dataset`, but received argument of type `{type(dataset)}`."
assert isinstance(checkpoint, textattack.shared.Checkpoint), f"`checkpoint` argument must be of type `textattack.shared.Checkpoint`, but received argument of type `{type(checkpoint)}`."
self.attack_log_manager = checkpoint.attack_log_manager
self._attack(dataset, checkpoint=checkpoint)
def resume_attack_parallel(self, dataset, checkpoint, num_workers_per_device):
"""Resume attacking `dataset` from saved `checkpoint`.
Args:
dataset (textattack.datasets.Dataset): dataset to attack.
checkpoint (textattack.shared.Checkpoint): checkpoint object that has previously been saved.
"""
assert isinstance(dataset, textattack.datasets.Dataset), f"`dataset` argument must be of type `textattack.datasets.Dataset`, but received argument of type `{type(dataset)}`."
assert isinstance(checkpoint, textattack.shared.Checkpoint), f"`checkpoint` argument must be of type `textattack.shared.Checkpoint`, but received argument of type `{type(checkpoint)}`."
self.attack_log_manager = checkpoint.attack_log_manager
self._attack_parallel(dataset, num_workers_per_device, checkpoint=checkpoint)
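A usage sketch tying the new pieces together. The import path for ``Attacker`` and the way ``AttackArgs`` is populated are assumptions; ``AttackArgs`` is still a stub in this commit, so attributes are set by hand to the fields ``_attack`` reads:

    import textattack
    from textattack.shared.attacker import Attacker  # assumed module path for this new file

    data = [("I enjoyed the movie a lot!", 1), ("Absolutely horrible film.", 0)]
    dataset = textattack.datasets.Dataset(data, lang="en")

    attack = ...  # any textattack.shared.Attack, e.g. from an attack recipe

    attack_args = textattack.shared.AttackArgs()
    attack_args.num_examples = len(data)   # fields below are assumed; `AttackArgs` is a stub
    attack_args.attack_n = False
    attack_args.disable_stdout = False
    attack_args.checkpoint_interval = None

    attacker = Attacker(attack)
    attacker.attack(dataset, attack_args)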

View File

@@ -24,17 +24,17 @@ class Checkpoint:
checkpoints.
Args:
args: Command line arguments of the original attack
log_manager (AttackLogManager): Object for storing attack results
args: Arguments of the original attack
attack_log_manager (AttackLogManager): Object for storing attack results
worklist (deque[int]): List of examples that will be attacked. Examples are represented by their indices within the dataset.
worklist_tail (int): Highest index that had been in the worklist at any given time. Used to get the next dataset element
when attacking with `attack_n` = True.
chkpt_time (float): epoch time representing when checkpoint was made
"""
def __init__(self, args, log_manager, worklist, worklist_tail, chkpt_time=None):
def __init__(self, args, attack_log_manager, worklist, worklist_tail, chkpt_time=None):
self.args = copy.deepcopy(args)
self.log_manager = log_manager
self.attack_log_manager = attack_log_manager
self.worklist = worklist
self.worklist_tail = worklist_tail
if chkpt_time:
@@ -137,26 +137,26 @@ class Checkpoint:
@property
def results_count(self):
"""Return number of attacks made so far."""
return len(self.log_manager.results)
return len(self.attack_log_manager.results)
@property
def num_skipped_attacks(self):
return sum(isinstance(r, SkippedAttackResult) for r in self.log_manager.results)
return sum(isinstance(r, SkippedAttackResult) for r in self.attack_log_manager.results)
@property
def num_failed_attacks(self):
return sum(isinstance(r, FailedAttackResult) for r in self.log_manager.results)
return sum(isinstance(r, FailedAttackResult) for r in self.attack_log_manager.results)
@property
def num_successful_attacks(self):
return sum(
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
isinstance(r, SuccessfulAttackResult) for r in self.attack_log_manager.results
)
@property
def num_maximized_attacks(self):
return sum(
isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
isinstance(r, MaximizedAttackResult) for r in self.attack_log_manager.results
)
@property
@@ -209,7 +209,7 @@ class Checkpoint:
), "Recorded number of remaining attacks and size of worklist are different."
results_set = set()
for result in self.log_manager.results:
for result in self.attack_log_manager.results:
results_set.add(result.original_text)
assert len(results_set) == self.results_count, "Duplicate AttackResults found."
assert len(results_set) == self.results_count, "Duplicate `AttackResults` found."
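A sketch of the resume flow implied by the renamed attribute (the checkpoint path is illustrative, and ``Checkpoint.load`` is assumed to work as in the existing class):

    checkpoint = textattack.shared.Checkpoint.load("checkpoints/1609070558.ta.chkpt")
    attacker = Attacker(attack)  # `attack` rebuilt the same way as in the original run
    attacker.resume_attack(dataset, checkpoint)
    # `resume_attack` swaps in `checkpoint.attack_log_manager` and continues from `checkpoint.worklist`.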