mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge branch 'master' into train
@@ -113,7 +113,7 @@ Follow these steps to start contributing:
|
||||
|
||||
```bash
|
||||
$ cd TextAttack
|
||||
$ pip install -e .
|
||||
$ pip install -e ".[dev]"
|
||||
$ pip install black isort pytest pytest-xdist
|
||||
```
|
||||
|
||||
|
||||
30
README.md
@@ -1,11 +1,10 @@
|
||||
|
||||
|
||||
<h1 align="center">TextAttack 🐙</h1>
|
||||
|
||||
<p align="center">Generating adversarial examples for NLP models</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://textattack.readthedocs.io/">Docs</a> •
|
||||
<a href="https://textattack.readthedocs.io/">Docs</a>
|
||||
<br>
|
||||
<a href="#about">About</a> •
|
||||
<a href="#setup">Setup</a> •
|
||||
<a href="#usage">Usage</a> •
|
||||
@@ -37,10 +36,9 @@ pip install textattack
|
||||
Once TextAttack is installed, you can run it via command-line (`textattack ...`)
|
||||
or via the python module (`python -m textattack ...`).
|
||||
|
||||
### Configuration
|
||||
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
|
||||
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
|
||||
environment variable `TA_CACHE_DIR`.
|
||||
> TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
|
||||
> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
|
||||
> environment variable `TA_CACHE_DIR` (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).
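
As a rough illustration of how such an override can be resolved, the sketch below reads `TA_CACHE_DIR` and falls back to the default path. The helper name `resolve_cache_dir` is hypothetical and is not part of TextAttack's API.

```python
import os


def resolve_cache_dir(environ=None):
    """Return the cache directory, preferring TA_CACHE_DIR when it is set.

    `resolve_cache_dir` is a hypothetical helper used only for illustration.
    """
    environ = os.environ if environ is None else environ
    default = os.path.expanduser("~/.cache/textattack")
    return environ.get("TA_CACHE_DIR", default)


print(resolve_cache_dir({"TA_CACHE_DIR": "/tmp/"}))  # uses the override
print(resolve_cache_dir({}))  # falls back to the expanded default path
```

Passing the environment as an argument keeps the lookup easy to test without touching the real process environment.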
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -51,7 +49,9 @@ information about all commands using `textattack --help`, or a specific command
|
||||
|
||||
### Running Attacks
|
||||
|
||||
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
|
||||
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack,
|
||||
including building a custom transformation and a custom constraint. These examples can also be viewed
|
||||
through the [documentation website](https://textattack.readthedocs.io/en/latest).
|
||||
|
||||
The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
|
||||
|
||||
@@ -185,9 +185,9 @@ textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 -
|
||||
|
||||
## Design
|
||||
|
||||
### TokenizedText
|
||||
### AttackedText
|
||||
|
||||
To allow for word replacement after a sequence has been tokenized, we include a `TokenizedText` object
|
||||
To allow for word replacement after a sequence has been tokenized, we include an `AttackedText` object
|
||||
which maintains both a list of tokens and the original text, with punctuation. We use this object in favor of a list of words or just raw text.
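
To make the idea concrete, here is a minimal sketch of an object that keeps the raw text and a word list side by side. The class `SimpleAttackedText` and its regular expression are assumptions made for illustration; the real `AttackedText` is considerably richer.

```python
import re


class SimpleAttackedText:
    """Hypothetical stand-in for AttackedText: keeps raw text plus its words."""

    def __init__(self, text):
        self.text = text
        # Words are runs of word characters or apostrophes; punctuation stays in `text`.
        self._spans = [m.span() for m in re.finditer(r"[\w']+", text)]
        self.words = [text[start:end] for start, end in self._spans]

    def replace_word_at_index(self, index, new_word):
        """Return a new object with one word swapped and punctuation untouched."""
        start, end = self._spans[index]
        return SimpleAttackedText(self.text[:start] + new_word + self.text[end:])


original = SimpleAttackedText("The movie was great, honestly!")
perturbed = original.replace_word_at_index(3, "awful")
print(perturbed.text)   # The movie was awful, honestly!
print(perturbed.words)  # ['The', 'movie', 'was', 'awful', 'honestly']
```

Because the word spans index into the original string, a replacement leaves commas, spacing, and capitalization of the surrounding text intact.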
|
||||
|
||||
### Models and Datasets
|
||||
@@ -252,23 +252,23 @@ You can then run attacks on samples from this dataset by adding the argument `--
|
||||
|
||||
### Attacks
|
||||
|
||||
The `attack_one` method in an `Attack` takes as input a `TokenizedText`, and outputs either a `SuccessfulAttackResult` if it succeeds or a `FailedAttackResult` if it fails. We formulate an attack as consisting of four components: a **goal function** which determines if the attack has succeeded, **constraints** defining which perturbations are valid, a **transformation** that generates potential modifications given an input, and a **search method** which traverses through the search space of possible perturbations.
|
||||
The `attack_one` method in an `Attack` takes as input an `AttackedText`, and outputs either a `SuccessfulAttackResult` if it succeeds or a `FailedAttackResult` if it fails. We formulate an attack as consisting of four components: a **goal function** which determines if the attack has succeeded, **constraints** defining which perturbations are valid, a **transformation** that generates potential modifications given an input, and a **search method** which traverses through the search space of possible perturbations.
|
||||
|
||||
### Goal Functions
|
||||
|
||||
A `GoalFunction` takes as input a `TokenizedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
|
||||
A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
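
A toy version of an untargeted classification goal might look like the sketch below. The model, the `ToyGoalResult` type, and the function names are all hypothetical; only the score formula (one minus the probability of the ground-truth label) mirrors the code changed in this commit.

```python
from dataclasses import dataclass


@dataclass
class ToyGoalResult:
    """Hypothetical stand-in for GoalFunctionResult."""
    text: str
    output: int
    score: float
    succeeded: bool


def toy_model(text):
    """Toy classifier: returns [P(negative), P(positive)] from a keyword count."""
    positive = sum(word in {"great", "good"} for word in text.lower().split())
    p_pos = min(0.9, 0.2 + 0.7 * positive)
    return [1.0 - p_pos, p_pos]


def untargeted_goal(text, ground_truth_output):
    """Succeed when the predicted label differs from the ground truth."""
    probs = toy_model(text)
    predicted = max(range(len(probs)), key=probs.__getitem__)
    score = 1.0 - probs[ground_truth_output]  # higher means closer to success
    return ToyGoalResult(text, predicted, score, predicted != ground_truth_output)


print(untargeted_goal("the movie was great", 1).succeeded)  # False
print(untargeted_goal("the movie was fine", 1).succeeded)   # True
```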
|
||||
|
||||
### Constraints
|
||||
|
||||
A `Constraint` takes as input a current `TokenizedText`, and a list of transformed `TokenizedText`s. For each transformed option, it returns a boolean representing whether the constraint is met.
|
||||
A `Constraint` takes as input a current `AttackedText`, and a list of transformed `AttackedText`s. For each transformed option, it returns a boolean representing whether the constraint is met.
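
The one-boolean-per-candidate contract can be sketched with plain strings. The function `max_words_changed` below is a hypothetical constraint written for this example, not one of TextAttack's built-in constraints.

```python
def max_words_changed(current_text, transformed_texts, limit=1):
    """Hypothetical constraint: allow at most `limit` differing word positions.

    Returns one boolean per transformed text, in the same order.
    """
    current_words = current_text.split()
    verdicts = []
    for candidate in transformed_texts:
        words = candidate.split()
        if len(words) != len(current_words):
            verdicts.append(False)  # insertions and deletions are out of scope here
            continue
        changed = sum(a != b for a, b in zip(current_words, words))
        verdicts.append(changed <= limit)
    return verdicts


print(max_words_changed("a good film", ["a fine film", "a fine movie", "good film"]))
# [True, False, False]
```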
|
||||
|
||||
### Transformations
|
||||
|
||||
A `Transformation` takes as input a `TokenizedText` and returns a list of possible transformed `TokenizedText`s. For example, a transformation might return all possible synonym replacements.
|
||||
A `Transformation` takes as input an `AttackedText` and returns a list of possible transformed `AttackedText`s. For example, a transformation might return all possible synonym replacements.
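
A synonym-replacement transformation can be sketched as follows. The synonym table and the function `synonym_swaps` are assumptions for illustration; they operate on plain strings rather than on `AttackedText` objects.

```python
SYNONYMS = {"good": ["fine", "decent"], "film": ["movie"]}  # toy table, not a real resource


def synonym_swaps(text, synonyms=SYNONYMS):
    """Hypothetical transformation: every text reachable by one synonym swap."""
    words = text.split()
    candidates = []
    for i, word in enumerate(words):
        for replacement in synonyms.get(word, []):
            candidates.append(" ".join(words[:i] + [replacement] + words[i + 1:]))
    return candidates


print(synonym_swaps("a good film"))
# ['a fine film', 'a decent film', 'a good movie']
```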
|
||||
|
||||
### Search Methods
|
||||
|
||||
A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult`. The search is given access to the `get_transformations` function, which takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.
|
||||
A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult`. The search is given access to the `get_transformations` function, which takes as input an `AttackedText` object and outputs the list of possible transformations that satisfy all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.
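
The loop described here can be sketched as a simple greedy search. Everything below is hypothetical: the callables take and return plain strings and tuples instead of `AttackedText` and `GoalFunctionResult`, and the greedy strategy is only one of many possible search methods.

```python
def greedy_search(initial_text, get_transformations, get_goal_result, max_steps=10):
    """Hypothetical greedy search: repeatedly take the best-scoring candidate.

    `get_goal_result(text)` returns (score, succeeded); higher scores are better.
    """
    current = initial_text
    current_score, succeeded = get_goal_result(current)
    for _ in range(max_steps):
        if succeeded:
            break
        candidates = get_transformations(current)
        if not candidates:
            break  # search space exhausted
        scored = [(get_goal_result(c), c) for c in candidates]
        (best_score, best_succeeded), best = max(scored, key=lambda item: item[0][0])
        if best_score <= current_score:
            break  # no candidate improves the score
        current, current_score, succeeded = best, best_score, best_succeeded
    return current, succeeded


# Toy components: the goal is to remove every occurrence of "good".
def toy_transformations(text):
    words = text.split()
    return [" ".join(words[:i] + ["fine"] + words[i + 1:])
            for i, w in enumerate(words) if w == "good"]


def toy_goal(text):
    remaining = text.split().count("good")
    return (-remaining, remaining == 0)


print(greedy_search("good plot good cast", toy_transformations, toy_goal))
# ('fine plot fine cast', True)
```

The search stops for one of three reasons: the goal is met, no candidates remain, or no candidate improves on the current score.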
|
||||
|
||||
## Contributing to TextAttack
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ nlp
|
||||
nltk
|
||||
numpy
|
||||
pandas
|
||||
pyyaml>=5.1
|
||||
scikit-learn
|
||||
scipy==1.4.1
|
||||
sentence_transformers
|
||||
|
||||
6
setup.py
@@ -6,6 +6,11 @@ from docs import conf as docs_conf
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
extras = {}
|
||||
# For developers, install development tools along with all optional dependencies.
|
||||
extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"]
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name="textattack",
|
||||
version=docs_conf.release,
|
||||
@@ -28,6 +33,7 @@ setuptools.setup(
|
||||
"wandb*",
|
||||
]
|
||||
),
|
||||
extras_require=extras,
|
||||
entry_points={
|
||||
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
|
||||
},
|
||||
|
||||
@@ -154,6 +154,8 @@ def parse_goal_function_from_args(args, model):
|
||||
else:
|
||||
raise ValueError(f"Error: unsupported goal_function {goal_function}")
|
||||
goal_function.query_budget = args.query_budget
|
||||
goal_function.model_batch_size = args.model_batch_size
|
||||
goal_function.model_cache_size = args.model_cache_size
|
||||
return goal_function
|
||||
|
||||
|
||||
@@ -192,6 +194,9 @@ def parse_attack_from_args(args):
|
||||
else:
|
||||
raise ValueError(f"Invalid recipe {args.recipe}")
|
||||
recipe.goal_function.query_budget = args.query_budget
|
||||
recipe.goal_function.model_batch_size = args.model_batch_size
|
||||
recipe.goal_function.model_cache_size = args.model_cache_size
|
||||
recipe.constraint_cache_size = args.constraint_cache_size
|
||||
return recipe
|
||||
elif args.attack_from_file:
|
||||
if ":" in args.attack_from_file:
|
||||
@@ -219,7 +224,11 @@ def parse_attack_from_args(args):
|
||||
else:
|
||||
raise ValueError(f"Error: unsupported attack {args.search}")
|
||||
return textattack.shared.Attack(
|
||||
goal_function, constraints, transformation, search_method
|
||||
goal_function,
|
||||
constraints,
|
||||
transformation,
|
||||
search_method,
|
||||
constraint_cache_size=args.constraint_cache_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -434,7 +443,6 @@ def parse_checkpoint_from_args(args):
|
||||
checkpoint_path = args.checkpoint_file
|
||||
|
||||
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
|
||||
set_seed(checkpoint.args.random_seed)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -451,7 +459,6 @@ def merge_checkpoint_args(saved_args, cmdline_args):
|
||||
""" Merge previously saved arguments for checkpoint and newly entered arguments """
|
||||
args = copy.deepcopy(saved_args)
|
||||
# Newly entered arguments take precedence
|
||||
args.checkpoint_resume = cmdline_args.checkpoint_resume
|
||||
args.parallel = cmdline_args.parallel
|
||||
# If set, replace
|
||||
if cmdline_args.checkpoint_dir:
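
The merge pattern in this hunk (copy the saved arguments, then let selected command-line values win) can be sketched in isolation. The function `merge_args` and its `override` parameter are hypothetical and only illustrate the idea; they are not the body of `merge_checkpoint_args`.

```python
import copy
from argparse import Namespace


def merge_args(saved_args, cmdline_args, override=("checkpoint_resume", "parallel")):
    """Hypothetical sketch: start from saved args, let chosen new values win."""
    merged = copy.deepcopy(saved_args)  # never mutate the checkpoint's own copy
    for name in override:
        setattr(merged, name, getattr(cmdline_args, name))
    return merged


saved = Namespace(recipe="textfooler", parallel=False, checkpoint_resume=False)
new = Namespace(parallel=True, checkpoint_resume=True)
merged = merge_args(saved, new)
print(merged.recipe, merged.parallel, merged.checkpoint_resume)  # textfooler True True
print(saved.parallel)  # False
```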
|
||||
|
||||
@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
|
||||
default=float("inf"),
|
||||
help="The maximum number of model queries allowed per example attacked.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The batch size for making calls to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
help="The maximum number of items to keep in the model results cache at once.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--constraint-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
help="The maximum number of items to keep in the constraints cache at once.",
|
||||
)
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
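
For reference, the three options added in this hunk can be exercised with a standalone parser. This is a sketch that reuses the option names, types, defaults, and help strings shown above; the bare `ArgumentParser` is an assumption and stands in for the real `textattack attack` subcommand parser.

```python
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--model-batch-size", type=int, default=32,
                    help="The batch size for making calls to the model.")
parser.add_argument("--model-cache-size", type=int, default=2 ** 18,
                    help="The maximum number of items to keep in the model results cache at once.")
parser.add_argument("--constraint-cache-size", type=int, default=2 ** 18,
                    help="The maximum number of items to keep in the constraints cache at once.")

# argparse converts dashes in option names to underscores on the namespace.
args = parser.parse_args(["--model-batch-size", "8"])
print(args.model_batch_size, args.model_cache_size, args.constraint_cache_size)
# 8 262144 262144
```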
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
|
||||
import textattack
|
||||
from textattack.commands import TextAttackCommand
|
||||
from textattack.commands.attack.attack_args_helpers import *
|
||||
|
||||
|
||||
class AttackResumeCommand(TextAttackCommand):
|
||||
@@ -10,9 +12,11 @@ class AttackResumeCommand(TextAttackCommand):
|
||||
A command line parser to resume a checkpointed attack from user specifications.
|
||||
"""
|
||||
|
||||
def run(self):
|
||||
textattack.shared.utils.set_seed(self.random_seed)
|
||||
self.checkpoint_resume = True
|
||||
def run(self, args):
|
||||
checkpoint = parse_checkpoint_from_args(args)
|
||||
args = merge_checkpoint_args(checkpoint.args, args)
|
||||
textattack.shared.utils.set_seed(args.random_seed)
|
||||
args.checkpoint_resume = True
|
||||
|
||||
# Run attack from checkpoint.
|
||||
from textattack.commands.attack.run_attack_parallel import run as run_parallel
|
||||
@@ -20,10 +24,10 @@ class AttackResumeCommand(TextAttackCommand):
|
||||
run as run_single_threaded,
|
||||
)
|
||||
|
||||
if self.parallel:
|
||||
run_parallel(self)
|
||||
if args.parallel:
|
||||
run_parallel(args, checkpoint=checkpoint)
|
||||
else:
|
||||
run_single_threaded(self)
|
||||
run_single_threaded(args, checkpoint=checkpoint)
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(main_parser: ArgumentParser):
|
||||
|
||||
@@ -48,26 +48,22 @@ def attack_from_queue(args, in_queue, out_queue):
|
||||
exit()
|
||||
|
||||
|
||||
def run(args):
|
||||
def run(args, checkpoint=None):
|
||||
pytorch_multiprocessing_workaround()
|
||||
|
||||
if args.checkpoint_resume:
|
||||
# Override current args with checkpoint args
|
||||
resume_checkpoint = parse_checkpoint_from_args(args)
|
||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||
num_total_examples = args.num_examples
|
||||
|
||||
num_remaining_attacks = resume_checkpoint.num_remaining_attacks
|
||||
num_total_examples = args.num_examples
|
||||
worklist = resume_checkpoint.worklist
|
||||
worklist_tail = resume_checkpoint.worklist_tail
|
||||
if args.checkpoint_resume:
|
||||
num_remaining_attacks = checkpoint.num_remaining_attacks
|
||||
worklist = checkpoint.worklist
|
||||
worklist_tail = checkpoint.worklist_tail
|
||||
logger.info(
|
||||
"Recovered from checkpoint previously saved at {}".format(
|
||||
resume_checkpoint.datetime
|
||||
checkpoint.datetime
|
||||
)
|
||||
)
|
||||
print(resume_checkpoint, "\n")
|
||||
print(checkpoint, "\n")
|
||||
else:
|
||||
num_total_examples = args.num_examples
|
||||
num_remaining_attacks = num_total_examples
|
||||
worklist = deque(range(0, num_total_examples))
|
||||
worklist_tail = worklist[-1]
|
||||
@@ -79,7 +75,7 @@ def run(args):
|
||||
start_time = time.time()
|
||||
|
||||
if args.checkpoint_resume:
|
||||
attack_log_manager = resume_checkpoint.log_manager
|
||||
attack_log_manager = checkpoint.log_manager
|
||||
else:
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
@@ -107,9 +103,9 @@ def run(args):
|
||||
)
|
||||
# Log results asynchronously and update progress bar.
|
||||
if args.checkpoint_resume:
|
||||
num_results = resume_checkpoint.results_count
|
||||
num_failures = resume_checkpoint.num_failed_attacks
|
||||
num_successes = resume_checkpoint.num_successful_attacks
|
||||
num_results = checkpoint.results_count
|
||||
num_failures = checkpoint.num_failed_attacks
|
||||
num_successes = checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
@@ -157,10 +153,10 @@ def run(args):
|
||||
args.checkpoint_interval
|
||||
and len(attack_log_manager.results) % args.checkpoint_interval == 0
|
||||
):
|
||||
checkpoint = textattack.shared.Checkpoint(
|
||||
new_checkpoint = textattack.shared.Checkpoint(
|
||||
args, attack_log_manager, worklist, worklist_tail
|
||||
)
|
||||
checkpoint.save()
|
||||
new_checkpoint.save()
|
||||
attack_log_manager.flush()
|
||||
|
||||
pbar.close()
|
||||
|
||||
@@ -16,7 +16,7 @@ from .attack_args_helpers import *
|
||||
logger = textattack.shared.logger
|
||||
|
||||
|
||||
def run(args):
|
||||
def run(args, checkpoint=None):
|
||||
# Only use one GPU, if we have one.
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
@@ -28,20 +28,16 @@ def run(args):
|
||||
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
|
||||
|
||||
if args.checkpoint_resume:
|
||||
# Override current args with checkpoint args
|
||||
resume_checkpoint = parse_checkpoint_from_args(args)
|
||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||
|
||||
num_remaining_attacks = resume_checkpoint.num_remaining_attacks
|
||||
worklist = resume_checkpoint.worklist
|
||||
worklist_tail = resume_checkpoint.worklist_tail
|
||||
num_remaining_attacks = checkpoint.num_remaining_attacks
|
||||
worklist = checkpoint.worklist
|
||||
worklist_tail = checkpoint.worklist_tail
|
||||
|
||||
logger.info(
|
||||
"Recovered from checkpoint previously saved at {}".format(
|
||||
resume_checkpoint.datetime
|
||||
checkpoint.datetime
|
||||
)
|
||||
)
|
||||
print(resume_checkpoint, "\n")
|
||||
print(checkpoint, "\n")
|
||||
else:
|
||||
num_remaining_attacks = args.num_examples
|
||||
worklist = deque(range(0, args.num_examples))
|
||||
@@ -55,7 +51,7 @@ def run(args):
|
||||
|
||||
# Logger
|
||||
if args.checkpoint_resume:
|
||||
attack_log_manager = resume_checkpoint.log_manager
|
||||
attack_log_manager = checkpoint.log_manager
|
||||
else:
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
@@ -89,9 +85,9 @@ def run(args):
|
||||
|
||||
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
|
||||
if args.checkpoint_resume:
|
||||
num_results = resume_checkpoint.results_count
|
||||
num_failures = resume_checkpoint.num_failed_attacks
|
||||
num_successes = resume_checkpoint.num_successful_attacks
|
||||
num_results = checkpoint.results_count
|
||||
num_failures = checkpoint.num_failed_attacks
|
||||
num_successes = checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
@@ -128,10 +124,10 @@ def run(args):
|
||||
args.checkpoint_interval
|
||||
and len(attack_log_manager.results) % args.checkpoint_interval == 0
|
||||
):
|
||||
checkpoint = textattack.shared.Checkpoint(
|
||||
new_checkpoint = textattack.shared.Checkpoint(
|
||||
args, attack_log_manager, worklist, worklist_tail
|
||||
)
|
||||
checkpoint.save()
|
||||
new_checkpoint.save()
|
||||
attack_log_manager.flush()
|
||||
|
||||
pbar.close()
|
||||
|
||||
@@ -19,7 +19,7 @@ class UntargetedClassification(ClassificationGoalFunction):
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
if self.target_max_score:
|
||||
return model_output[ground_truth_output] < self.target_max_score
|
||||
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
||||
elif (model_output.numel() == 1) and isinstance(ground_truth_output, float):
|
||||
return abs(ground_truth_output - model_output.item()) >= (
|
||||
self.target_max_score or 0.5
|
||||
)
|
||||
@@ -29,7 +29,7 @@ class UntargetedClassification(ClassificationGoalFunction):
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
# If the model outputs a single number and the ground truth output is
|
||||
# a float, we assume that this is a regression task.
|
||||
if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
||||
if (model_output.numel() == 1) and isinstance(ground_truth_output, float):
|
||||
return abs(model_output.item() - ground_truth_output)
|
||||
else:
|
||||
return 1 - model_output[ground_truth_output]
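
Both hunks above replace `is 1` with `== 1`. The `is` operator tests object identity, not value equality, so a comparison against an integer is only reliable with `==`. The sketch below shows the difference with a hypothetical wrapper class `Count`; it does not use PyTorch.

```python
class Count:
    """Minimal value type that compares equal to the integer it wraps."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other


n = Count(1)
one = 1
print(n == one)  # True: equality compares values
print(n is one)  # False: identity asks whether both names refer to the same object
```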
|
||||
|
||||
@@ -13,12 +13,21 @@ class GoalFunction:
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
query_budget: The maximum number of model queries allowed.
|
||||
model: The model used for evaluation.
|
||||
query_budget (float): The maximum number of model queries allowed.
|
||||
model_batch_size (int): The batch size for making calls to the model.
|
||||
model_cache_size (int): The maximum number of items to keep in the model
|
||||
results cache at once.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
use_cache=True,
|
||||
query_budget=float("inf"),
|
||||
model_batch_size=32,
|
||||
model_cache_size=2 ** 18,
|
||||
):
|
||||
validators.validate_model_goal_function_compatibility(
|
||||
self.__class__, model.__class__
|
||||
@@ -35,14 +44,15 @@ class GoalFunction:
|
||||
self.use_cache = use_cache
|
||||
self.num_queries = 0
|
||||
self.query_budget = query_budget
|
||||
self.model_batch_size = model_batch_size
|
||||
if self.use_cache:
|
||||
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
|
||||
self._call_model_cache = lru.LRU(model_cache_size)
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``\,
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``,
|
||||
due to misprediction by the model.
|
||||
"""
|
||||
model_outputs = self._call_model([attacked_text])
|
||||
@@ -125,7 +135,9 @@ class GoalFunction:
|
||||
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = batch_model_predict(self.model, ids)
|
||||
outputs = batch_model_predict(
|
||||
self.model, ids, batch_size=self.model_batch_size
|
||||
)
|
||||
|
||||
return self._process_model_outputs(attacked_text_list, outputs)
|
||||
|
||||
|
||||
@@ -12,8 +12,6 @@ class AttackLogManager:
|
||||
def __init__(self):
|
||||
self.loggers = []
|
||||
self.results = []
|
||||
self.max_words_changed = 0
|
||||
self.max_seq_len = 2 ** 16
|
||||
|
||||
def enable_stdout(self):
|
||||
self.loggers.append(FileLogger(stdout=True))
|
||||
@@ -71,7 +69,9 @@ class AttackLogManager:
|
||||
# Count things about attacks.
|
||||
all_num_words = np.zeros(len(self.results))
|
||||
perturbed_word_percentages = np.zeros(len(self.results))
|
||||
num_words_changed_until_success = np.zeros(self.max_seq_len)
|
||||
num_words_changed_until_success = np.zeros(
|
||||
2 ** 16
|
||||
) # @ TODO: be smarter about this
|
||||
failed_attacks = 0
|
||||
skipped_attacks = 0
|
||||
successful_attacks = 0
|
||||
@@ -156,7 +156,7 @@ class AttackLogManager:
|
||||
summary_table_rows, "Attack Results", "attack_results_summary"
|
||||
)
|
||||
# Show histogram of words changed.
|
||||
numbins = max(self.max_words_changed, 10)
|
||||
numbins = max(max_words_changed, 10)
|
||||
for logger in self.loggers:
|
||||
logger.log_hist(
|
||||
num_words_changed_until_success[:numbins],
|
||||
|
||||
@@ -6,7 +6,6 @@ Entailment by Jin et al., 2019.
|
||||
See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
|
||||
"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from textattack.search_methods import SearchMethod
|
||||
@@ -56,15 +55,19 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
leave_one_texts = [
|
||||
attacked_text.delete_word_at_index(i) for i in range(len_text)
|
||||
]
|
||||
leave_one_scores = self._get_index_order(initial_result, leave_one_texts)
|
||||
leave_one_scores, search_over = self._get_index_order(
|
||||
initial_result, leave_one_texts
|
||||
)
|
||||
elif self.wir_method == "random":
|
||||
leave_one_scores = torch.random(len_text)
|
||||
index_order = np.arange(len_text)
|
||||
np.random.shuffle(index_order)
|
||||
search_over = False
|
||||
|
||||
if self.ascending:
|
||||
index_order = (leave_one_scores).argsort()
|
||||
else:
|
||||
index_order = (-leave_one_scores).argsort()
|
||||
if self.wir_method != "random":
|
||||
if self.ascending:
|
||||
index_order = (leave_one_scores).argsort()
|
||||
else:
|
||||
index_order = (-leave_one_scores).argsort()
|
||||
|
||||
i = 0
|
||||
results = None
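
The ordering logic in this hunk (shuffle the indices for the random method, otherwise sort them by score) can be sketched without NumPy. The helper `index_order` and its `method` and `rng` parameters are hypothetical names chosen for this example.

```python
import random


def index_order(scores, ascending=False, method="score", rng=None):
    """Hypothetical helper: order word indices by score, or shuffle them."""
    indices = list(range(len(scores)))
    if method == "random":
        (rng or random).shuffle(indices)  # scores are ignored entirely
        return indices
    # sorted() is stable, so ties keep their original left-to-right order.
    return sorted(indices, key=lambda i: scores[i], reverse=not ascending)


print(index_order([0.1, 0.7, 0.3]))                  # [1, 2, 0]
print(index_order([0.1, 0.7, 0.3], ascending=True))  # [0, 2, 1]
shuffled = index_order([0.1, 0.7, 0.3], method="random", rng=random.Random(0))
print(sorted(shuffled))                              # [0, 1, 2]
```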
|
||||
|
||||
@@ -27,6 +27,7 @@ class Attack:
|
||||
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
|
||||
transformation: The transformation applied at each step of the attack.
|
||||
search_method: A strategy for exploring the search space of possible perturbations
|
||||
constraint_cache_size (int): The number of items to keep in the constraints cache.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,6 +36,7 @@ class Attack:
|
||||
constraints=[],
|
||||
transformation=None,
|
||||
search_method=None,
|
||||
constraint_cache_size=2 ** 18,
|
||||
):
|
||||
""" Initialize an attack object. Attacks can be run multiple times. """
|
||||
self.goal_function = goal_function
|
||||
@@ -68,7 +70,8 @@ class Attack:
|
||||
else:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
|
||||
self.constraint_cache_size = constraint_cache_size
|
||||
self.constraints_cache = lru.LRU(constraint_cache_size)
|
||||
|
||||
# Give search method access to functions for getting transformations and evaluating them
|
||||
self.search_method.get_transformations = self.get_transformations
|
||||
@@ -124,10 +127,10 @@ class Attack:
|
||||
)
|
||||
# Default to false for all original transformations.
|
||||
for original_transformed_text in transformed_texts:
|
||||
self.constraints_cache[original_transformed_text] = False
|
||||
self.constraints_cache[(current_text, original_transformed_text)] = False
|
||||
# Set unfiltered transformations to True in the cache.
|
||||
for filtered_text in filtered_texts:
|
||||
self.constraints_cache[filtered_text] = True
|
||||
self.constraints_cache[(current_text, filtered_text)] = True
|
||||
return filtered_texts
|
||||
|
||||
def _filter_transformations(
|
||||
@@ -145,18 +148,20 @@ class Attack:
|
||||
# Populate cache with transformed_texts
|
||||
uncached_texts = []
|
||||
for transformed_text in transformed_texts:
|
||||
if transformed_text not in self.constraints_cache:
|
||||
if (current_text, transformed_text) not in self.constraints_cache:
|
||||
uncached_texts.append(transformed_text)
|
||||
else:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.constraints_cache[transformed_text] = self.constraints_cache[
|
||||
transformed_text
|
||||
]
|
||||
self.constraints_cache[
|
||||
(current_text, transformed_text)
|
||||
] = self.constraints_cache[(current_text, transformed_text)]
|
||||
self._filter_transformations_uncached(
|
||||
uncached_texts, current_text, original_text=original_text
|
||||
)
|
||||
# Return transformed_texts from cache
|
||||
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
|
||||
filtered_texts = [
|
||||
t for t in transformed_texts if self.constraints_cache[(current_text, t)]
|
||||
]
|
||||
# Sort transformations to ensure order is preserved between runs
|
||||
filtered_texts.sort(key=lambda t: t.text)
|
||||
return filtered_texts
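
The hunks above change the cache key from the transformed text alone to the pair `(current_text, transformed_text)`. One plausible reason is that a constraint's verdict can depend on the text it is compared against. The sketch below shows pair-keyed lookups on a small least-recently-used cache; `TinyLRU` is a hypothetical stand-in built on `OrderedDict`, not the `lru.LRU` class the code imports.

```python
from collections import OrderedDict


class TinyLRU:
    """Hypothetical least-recently-used cache built on OrderedDict."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._data = OrderedDict()

    def __contains__(self, key):
        return key in self._data

    def __getitem__(self, key):
        self._data.move_to_end(key)  # reading counts as a use
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict the least recently used entry


cache = TinyLRU(capacity=2)
# The same candidate can be valid relative to one text and invalid relative to
# another, so the verdict is stored under the (current_text, transformed_text) pair.
cache[("a good film", "a fine film")] = True
cache[("a bad film", "a fine film")] = False
print(cache[("a good film", "a fine film")])   # True
cache[("a dull film", "a fine film")] = True   # evicts the ("a bad film", ...) entry
print(("a bad film", "a fine film") in cache)  # False
```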
|
||||
|
||||
@@ -10,16 +10,14 @@ import filelock
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
import yaml
|
||||
|
||||
|
||||
def path_in_cache(file_path):
|
||||
textattack_cache_dir = config("CACHE_DIR")
|
||||
try:
|
||||
os.makedirs(textattack_cache_dir)
|
||||
os.makedirs(TEXTATTACK_CACHE_DIR)
|
||||
except FileExistsError: # cache path exists
|
||||
pass
|
||||
return os.path.join(textattack_cache_dir, file_path)
|
||||
return os.path.join(TEXTATTACK_CACHE_DIR, file_path)
|
||||
|
||||
|
||||
def s3_url(uri):
|
||||
@@ -48,7 +46,7 @@ def download_if_needed(folder_name):
|
||||
return cache_dest_path
|
||||
# If the file isn't found yet, download the zip file to the cache.
|
||||
downloaded_file = tempfile.NamedTemporaryFile(
|
||||
dir=config("CACHE_DIR"), suffix=".zip", delete=False
|
||||
dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
|
||||
)
|
||||
http_get(folder_name, downloaded_file)
|
||||
# Move or unzip the file.
|
||||
@@ -107,7 +105,7 @@ logger.propagate = False
|
||||
|
||||
def _post_install():
|
||||
logger.info(
|
||||
"First time importing textattack: downloading remaining required packages."
|
||||
"First time running textattack: downloading remaining required packages."
|
||||
)
|
||||
logger.info("Downloading spaCy required packages.")
|
||||
import spacy
|
||||
@@ -122,28 +120,39 @@ def _post_install():
|
||||
nltk.download("stopwords")
|
||||
|
||||
|
||||
def set_cache_dir(cache_dir):
|
||||
""" Sets all relevant cache directories to ``TA_CACHE_DIR``. """
|
||||
# Tensorflow Hub cache directory
|
||||
os.environ["TFHUB_CACHE_DIR"] = cache_dir
|
||||
# HuggingFace `transformers` cache directory
|
||||
os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir
|
||||
# HuggingFace `nlp` cache directory
|
||||
os.environ["HF_HOME"] = cache_dir
|
||||
# Basic directory for Linux user-specific non-data files
|
||||
os.environ["XDG_CACHE_HOME"] = cache_dir
|
||||
|
||||
|
||||
def _post_install_if_needed():
|
||||
""" Runs _post_install if it hasn't been run since install. """
|
||||
# Check for post-install file.
|
||||
post_install_file_path = os.path.join(config("CACHE_DIR"), "post_install_check")
|
||||
post_install_file_path = path_in_cache("post_install_check")
|
||||
post_install_file_lock_path = post_install_file_path + ".lock"
|
||||
post_install_file_lock = filelock.FileLock(post_install_file_lock_path)
|
||||
post_install_file_lock.acquire()
|
||||
if os.path.exists(post_install_file_path):
|
||||
post_install_file_lock.release()
|
||||
return
|
||||
# Run post-install.
|
||||
_post_install()
|
||||
# Create file that indicates post-install completed.
|
||||
open(post_install_file_path, "w").close()
|
||||
post_install_file_lock.release()
|
||||
|
||||
|
||||
def config(key):
|
||||
return config_dict[key]
|
||||
TEXTATTACK_CACHE_DIR = os.environ.get(
|
||||
"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
|
||||
)
|
||||
if "TA_CACHE_DIR" in os.environ:
|
||||
set_cache_dir(os.environ["TA_CACHE_DIR"])
|
||||
|
||||
|
||||
config_dict = {
|
||||
"CACHE_DIR": os.environ.get(
|
||||
"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
|
||||
),
|
||||
}
|
||||
config_path = download_if_needed("config.yaml")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict.update(yaml.load(f, Loader=yaml.FullLoader))
|
||||
_post_install_if_needed()
|
||||
|
||||
@@ -86,3 +86,4 @@ def set_seed(random_seed):
|
||||
random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
torch.manual_seed(random_seed)
|
||||
torch.cuda.manual_seed(random_seed)
|
||||
|
||||
@@ -13,7 +13,7 @@ def batch_tokenize(tokenizer, attacked_text_list):
|
||||
return [tokenizer.encode(x) for x in inputs]
|
||||
|
||||
|
||||
def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
|
||||
def batch_model_predict(model, inputs, batch_size=32):
|
||||
outputs = []
|
||||
i = 0
|
||||
while i < len(inputs):
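
The hunk is cut off at the start of the batching loop. As a generic illustration of fixed-size batching (not a reconstruction of the rest of `batch_model_predict`), the sketch below splits the inputs into slices and concatenates the per-batch outputs. The names `batched`, `batch_predict`, and `predict_fn` are hypothetical.

```python
def batched(inputs, batch_size=32):
    """Yield consecutive slices of `inputs`, each at most `batch_size` long."""
    i = 0
    while i < len(inputs):
        yield inputs[i:i + batch_size]
        i += batch_size


def batch_predict(predict_fn, inputs, batch_size=32):
    """Hypothetical helper: call `predict_fn` per batch and concatenate the outputs."""
    outputs = []
    for batch in batched(inputs, batch_size):
        outputs.extend(predict_fn(batch))
    return outputs


print(batch_predict(lambda batch: [len(x) for x in batch], ["a", "bb", "ccc"], batch_size=2))
# [1, 2, 3]
```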
|
||||
|
||||