mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

smart cache dir, synchronous post-install hook, better constraint caching
@@ -113,7 +113,7 @@ Follow these steps to start contributing:
 ```bash
 $ cd TextAttack
-$ pip install -e .
+$ pip install -e . ".[dev]"
 $ pip install black isort pytest pytest-xdist
 ```
README.md
@@ -1,11 +1,10 @@
 <h1 align="center">TextAttack 🐙</h1>

 <p align="center">Generating adversarial examples for NLP models</p>

 <p align="center">
-  <a href="https://textattack.readthedocs.io/">Docs</a> •
+  <a href="https://textattack.readthedocs.io/">Docs</a>
   <br>
   <a href="#about">About</a> •
   <a href="#setup">Setup</a> •
   <a href="#usage">Usage</a> •
@@ -37,10 +36,9 @@ pip install textattack
 Once TextAttack is installed, you can run it via command-line (`textattack ...`)
 or via the python module (`python -m textattack ...`).

-### Configuration
-TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
-dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
-environment variable `TA_CACHE_DIR`.
+> TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
+> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
+> environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).

 ## Usage
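Aside: the configuration note above is prose only; the lookup it describes amounts to a one-line environment check. A minimal sketch in plain Python, using the same default path the note mentions (`cache_dir` is just an illustrative variable name):

```python
import os

# Honor TA_CACHE_DIR when it is set, otherwise fall back to ~/.cache/textattack/.
cache_dir = os.environ.get(
    "TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
)
print(cache_dir)
```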
@@ -51,7 +49,9 @@ information about all commands using `textattack --help`, or a specific command

 ### Running Attacks

-The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
+The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack,
+including building a custom transformation and a custom constraint. These examples can also be viewed
+through the [documentation website](https://textattack.readthedocs.io/en/latest).

 The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
@@ -1,4 +1,3 @@
 click
 editdistance
 filelock
 language_tool_python
setup.py
@@ -6,6 +6,11 @@ from docs import conf as docs_conf
 with open("README.md", "r") as fh:
     long_description = fh.read()

+extras = {}
+# For developers, install development tools along with all optional dependencies.
+extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"]
+
+
 setuptools.setup(
     name="textattack",
     version=docs_conf.release,
@@ -28,6 +33,7 @@ setuptools.setup(
             "wandb*",
         ]
     ),
+    extras_require=extras,
     entry_points={
         "console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
     },
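For context, `extras_require` declares optional dependency groups that pip installs only when asked, which is what lets the CONTRIBUTING instructions collapse to `pip install -e . ".[dev]"`. A stripped-down sketch of the pattern with placeholder package metadata (not TextAttack's real `setup()` call):

```python
import setuptools

# Placeholder metadata; only the extras_require mechanism is the point here.
extras = {"dev": ["black", "isort", "pytest", "pytest-xdist"]}

setuptools.setup(
    name="example-package",
    version="0.0.1",
    packages=setuptools.find_packages(),
    extras_require=extras,  # installed only on request, e.g. `pip install ".[dev]"`
)
```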
@@ -153,6 +153,8 @@ def parse_goal_function_from_args(args, model):
     else:
         raise ValueError(f"Error: unsupported goal_function {goal_function}")
     goal_function.query_budget = args.query_budget
+    goal_function.model_batch_size = args.model_batch_size
+    goal_function.model_cache_size = args.model_cache_size
     return goal_function
@@ -191,6 +193,9 @@ def parse_attack_from_args(args):
         else:
             raise ValueError(f"Invalid recipe {args.recipe}")
         recipe.goal_function.query_budget = args.query_budget
+        recipe.goal_function.model_batch_size = args.model_batch_size
+        recipe.goal_function.model_cache_size = args.model_cache_size
+        recipe.constraint_cache_size = args.constraint_cache_size
         return recipe
     elif args.attack_from_file:
         if ":" in args.attack_from_file:
@@ -218,7 +223,11 @@ def parse_attack_from_args(args):
         else:
             raise ValueError(f"Error: unsupported attack {args.search}")
     return textattack.shared.Attack(
-        goal_function, constraints, transformation, search_method
+        goal_function,
+        constraints,
+        transformation,
+        search_method,
+        constraint_cache_size=args.constraint_cache_size,
     )
@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
             default=float("inf"),
             help="The maximum number of model queries allowed per example attacked.",
         )
+        parser.add_argument(
+            "--model-batch-size",
+            type=int,
+            default=32,
+            help="The batch size for making calls to the model.",
+        )
+        parser.add_argument(
+            "--model-cache-size",
+            type=int,
+            default=2 ** 18,
+            help="The maximum number of items to keep in the model results cache at once.",
+        )
+        parser.add_argument(
+            "--constraint-cache-size",
+            type=int,
+            default=2 ** 18,
+            help="The maximum number of items to keep in the constraints cache at once.",
+        )

         attack_group = parser.add_mutually_exclusive_group(required=False)
         search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
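The three new flags above feed the attributes that the argument-parsing helpers set on the goal function and recipe. A self-contained sketch of that flow, with a bare `argparse` parser and a `SimpleNamespace` standing in for TextAttack's objects (both are illustrative, not the project's actual code path):

```python
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--model-batch-size", type=int, default=32)
parser.add_argument("--model-cache-size", type=int, default=2 ** 18)
parser.add_argument("--constraint-cache-size", type=int, default=2 ** 18)
args = parser.parse_args(["--model-batch-size", "64"])

# Stand-in for the goal function returned by parse_goal_function_from_args.
goal_function = SimpleNamespace()
goal_function.model_batch_size = args.model_batch_size  # 64
goal_function.model_cache_size = args.model_cache_size  # 262144 (2 ** 18)
print(goal_function.model_batch_size, args.constraint_cache_size)
```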
@@ -19,7 +19,7 @@ class UntargetedClassification(ClassificationGoalFunction):
     def _is_goal_complete(self, model_output, ground_truth_output):
         if self.target_max_score:
             return model_output[ground_truth_output] < self.target_max_score
-        elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
+        elif (model_output.numel() == 1) and isinstance(ground_truth_output, float):
             return abs(ground_truth_output - model_output.item()) >= (
                 self.target_max_score or 0.5
             )
@@ -29,7 +29,7 @@ class UntargetedClassification(ClassificationGoalFunction):
     def _get_score(self, model_output, ground_truth_output):
         # If the model outputs a single number and the ground truth output is
         # a float, we assume that this is a regression task.
-        if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
+        if (model_output.numel() == 1) and isinstance(ground_truth_output, float):
             return abs(model_output.item() - ground_truth_output)
         else:
             return 1 - model_output[ground_truth_output]
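The `is 1` to `== 1` change fixes a real bug class: `is` compares object identity, not value. CPython caches small integers, so the old check usually happened to pass, but that is an implementation detail, and Python 3.8+ emits a SyntaxWarning for `is` with a literal. A quick illustration:

```python
# `==` compares values; `is` compares object identity. For ints created at
# runtime outside CPython's small-integer cache, identity checks are unreliable.
a = int("1001")
b = int("1001")
print(a == b)  # True: value equality, which is what the goal function needs
print(a is b)  # typically False: two distinct int objects holding the same value
```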
@@ -13,12 +13,21 @@ class GoalFunction:
     Evaluates how well a perturbed attacked_text object is achieving a specified goal.

     Args:
-        model: The PyTorch or TensorFlow model used for evaluation.
-        query_budget: The maximum number of model queries allowed.
+        model: The model used for evaluation.
+        query_budget (float): The maximum number of model queries allowed.
+        model_batch_size (int): The batch size for making calls to the model
+        model_cache_size (int): The maximum number of items to keep in the model
+            results cache at once
     """

     def __init__(
-        self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
+        self,
+        model,
+        tokenizer=None,
+        use_cache=True,
+        query_budget=float("inf"),
+        model_batch_size=32,
+        model_cache_size=2 ** 18,
     ):
         validators.validate_model_goal_function_compatibility(
             self.__class__, model.__class__
@@ -35,14 +44,15 @@ class GoalFunction:
         self.use_cache = use_cache
         self.num_queries = 0
         self.query_budget = query_budget
+        self.model_batch_size = model_batch_size
         if self.use_cache:
-            self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
+            self._call_model_cache = lru.LRU(model_cache_size)
         else:
             self._call_model_cache = None

     def should_skip(self, attacked_text, ground_truth_output):
         """
-        Returns whether or not the goal has already been completed for ``attacked_text``\,
+        Returns whether or not the goal has already been completed for ``attacked_text``,
         due to misprediction by the model.
         """
         model_outputs = self._call_model([attacked_text])
@@ -125,7 +135,9 @@ class GoalFunction:
         ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)

         with torch.no_grad():
-            outputs = batch_model_predict(self.model, ids)
+            outputs = batch_model_predict(
+                self.model, ids, batch_size=self.model_batch_size
+            )

         return self._process_model_outputs(attacked_text_list, outputs)
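`lru.LRU` comes from the `lru-dict` package: a dict-like container with a fixed capacity that evicts the least recently used key. A tiny sketch of the memoization pattern `_call_model` relies on, with a toy scoring function and a deliberately small cache so eviction is visible (names are illustrative):

```python
from lru import LRU  # provided by the lru-dict package

call_model_cache = LRU(4)  # the real default above is 2 ** 18 entries

def call_model_cached(text, call_model):
    """Return the model output for `text`, reusing a cached result when possible."""
    if text in call_model_cache:
        return call_model_cache[text]
    output = call_model(text)
    call_model_cache[text] = output
    return output

# Toy "model": score a string by its length.
for text in ["a", "bb", "a", "ccc", "dddd", "eeeee"]:
    call_model_cached(text, len)

print(len(call_model_cache))  # at most 4: least recently used entries were evicted
```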
@@ -71,7 +71,7 @@ class AttackLogManager:
         # Count things about attacks.
         all_num_words = np.zeros(len(self.results))
         perturbed_word_percentages = np.zeros(len(self.results))
-        num_words_changed_until_success = np.zeros(self.max_seq_len)
+        num_words_changed_until_success = np.zeros(2 ** 16)  # @ TODO: be smarter about this
         failed_attacks = 0
         skipped_attacks = 0
         successful_attacks = 0
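The replacement line still allocates a fixed 2**16-entry array, as its `@TODO` admits. One way that TODO might eventually be addressed (not something this commit does) is to tally words-changed counts in a dynamically sized counter:

```python
from collections import Counter

# Hypothetical alternative to the fixed-size np.zeros(2 ** 16) buffer: no
# upper bound on the number of words changed is baked in.
num_words_changed_until_success = Counter()
for words_changed in [3, 1, 3, 7]:  # placeholder per-result values
    num_words_changed_until_success[words_changed] += 1

print(num_words_changed_until_success.most_common())  # [(3, 2), (1, 1), (7, 1)]
```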
@@ -27,6 +27,7 @@ class Attack:
         constraints: A list of constraints to add to the attack, defining which perturbations are valid.
         transformation: The transformation applied at each step of the attack.
         search_method: A strategy for exploring the search space of possible perturbations
+        constraint_cache_size (int): the number of items to keep in the constraints cache
     """

     def __init__(
@@ -35,6 +36,7 @@ class Attack:
         constraints=[],
         transformation=None,
         search_method=None,
+        constraint_cache_size=2 ** 18,
     ):
         """ Initialize an attack object. Attacks can be run multiple times. """
         self.goal_function = goal_function
@@ -68,7 +70,8 @@ class Attack:
             else:
                 self.constraints.append(constraint)

-        self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
+        self.constraint_cache_size = constraint_cache_size
+        self.constraints_cache = lru.LRU(constraint_cache_size)

         # Give search method access to functions for getting transformations and evaluating them
         self.search_method.get_transformations = self.get_transformations
@@ -124,10 +127,10 @@ class Attack:
         )
         # Default to false for all original transformations.
         for original_transformed_text in transformed_texts:
-            self.constraints_cache[original_transformed_text] = False
+            self.constraints_cache[(current_text, original_transformed_text)] = False
         # Set unfiltered transformations to True in the cache.
         for filtered_text in filtered_texts:
-            self.constraints_cache[filtered_text] = True
+            self.constraints_cache[(current_text, filtered_text)] = True
         return filtered_texts

     def _filter_transformations(
@@ -145,18 +148,20 @@ class Attack:
         # Populate cache with transformed_texts
         uncached_texts = []
         for transformed_text in transformed_texts:
-            if transformed_text not in self.constraints_cache:
+            if (current_text, transformed_text) not in self.constraints_cache:
                 uncached_texts.append(transformed_text)
             else:
                 # promote transformed_text to the top of the LRU cache
-                self.constraints_cache[transformed_text] = self.constraints_cache[
-                    transformed_text
-                ]
+                self.constraints_cache[
+                    (current_text, transformed_text)
+                ] = self.constraints_cache[(current_text, transformed_text)]
         self._filter_transformations_uncached(
             uncached_texts, current_text, original_text=original_text
         )
         # Return transformed_texts from cache
-        filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
+        filtered_texts = [
+            t for t in transformed_texts if self.constraints_cache[(current_text, t)]
+        ]
         # Sort transformations to ensure order is preserved between runs
         filtered_texts.sort(key=lambda t: t.text)
         return filtered_texts
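The substantive change in `Attack` is that constraint results are now cached per `(current_text, transformed_text)` pair rather than per transformed text alone, so a verdict reached in one context is not silently reused in another. A compact sketch of the pattern with plain strings instead of `AttackedText` objects (the helper names are illustrative):

```python
from lru import LRU  # lru-dict, the same fixed-size mapping the Attack class uses

constraints_cache = LRU(2 ** 18)

def passes_constraints(current_text, transformed_text, check):
    """Cache constraint checks keyed by the (current_text, transformed_text) pair."""
    key = (current_text, transformed_text)
    if key not in constraints_cache:
        constraints_cache[key] = check(current_text, transformed_text)
    return constraints_cache[key]

# Toy constraint: the edit may change at most one word.
def one_word_changed(a, b):
    return sum(x != y for x, y in zip(a.split(), b.split())) <= 1

print(passes_constraints("the movie was great", "the movie was awful", one_word_changed))  # True
print(passes_constraints("the movie was great", "a film is awful", one_word_changed))      # False
```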
@@ -14,12 +14,11 @@ import yaml


 def path_in_cache(file_path):
-    textattack_cache_dir = config("CACHE_DIR")
     try:
-        os.makedirs(textattack_cache_dir)
+        os.makedirs(TEXTATTACK_CACHE_DIR)
     except FileExistsError:  # cache path exists
         pass
-    return os.path.join(textattack_cache_dir, file_path)
+    return os.path.join(TEXTATTACK_CACHE_DIR, file_path)


 def s3_url(uri):
@@ -48,7 +47,7 @@ def download_if_needed(folder_name):
         return cache_dest_path
     # If the file isn't found yet, download the zip file to the cache.
     downloaded_file = tempfile.NamedTemporaryFile(
-        dir=config("CACHE_DIR"), suffix=".zip", delete=False
+        dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
     )
     http_get(folder_name, downloaded_file)
     # Move or unzip the file.
@@ -107,7 +106,7 @@ logger.propagate = False

 def _post_install():
     logger.info(
-        "First time importing textattack: downloading remaining required packages."
+        "First time running textattack: downloading remaining required packages."
     )
     logger.info("Downloading spaCy required packages.")
     import spacy
@@ -122,28 +121,39 @@ def _post_install():
     nltk.download("stopwords")


-def _post_install_if_needed():
+def set_cache_dir(cache_dir):
+    """ Sets all relevant cache directories to ``TA_CACHE_DIR``. """
+    # Tensorflow Hub cache directory
+    os.environ["TFHUB_CACHE_DIR"] = cache_dir
+    # HuggingFace `transformers` cache directory
+    os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir
+    # HuggingFace `nlp` cache directory
+    os.environ["HF_HOME"] = cache_dir
+    # Basic directory for Linux user-specific non-data files
+    os.environ["XDG_CACHE_HOME"] = cache_dir
+
+
+def _post_install_if_needed(cache_dir):
     """ Runs _post_install if hasn't been run since install. """
     # Check for post-install file.
-    post_install_file_path = os.path.join(config("CACHE_DIR"), "post_install_check")
+    post_install_file_path = os.path.join(cache_dir, "post_install_check")
     post_install_file_lock_path = post_install_file_path + ".lock"
     post_install_file_lock = filelock.FileLock(post_install_file_lock_path)
     post_install_file_lock.acquire()
     if os.path.exists(post_install_file_path):
         post_install_file_lock.release()
         return
     # Run post-install.
     _post_install()
     # Create file that indicates post-install completed.
     open(post_install_file_path, "w").close()
     post_install_file_lock.release()


-def config(key):
-    return config_dict[key]
-
-
-config_dict = {
-    "CACHE_DIR": os.environ.get(
-        "TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
-    ),
-}
+TEXTATTACK_CACHE_DIR = os.environ.get(
+    "TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
+)
+if "TA_CACHE_DIR" in os.environ:
+    set_cache_dir(os.environ["TA_CACHE_DIR"])
+
+
 config_path = download_if_needed("config.yaml")
 with open(config_path, "r") as f:
     config_dict.update(yaml.load(f, Loader=yaml.FullLoader))
-_post_install_if_needed()
+_post_install_if_needed(TEXTATTACK_CACHE_DIR)
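`_post_install_if_needed` is a run-once guard: a marker file in the cache directory records that the one-time downloads finished, and a `filelock.FileLock` keeps concurrent processes from repeating the work. A minimal standalone sketch of the same pattern, with an illustrative path and setup callable rather than TextAttack's own:

```python
import os
import tempfile

import filelock  # the same library the diff uses for the post-install lock

def run_once(cache_dir, setup):
    """Run `setup()` exactly once per cache_dir, guarded by a marker file and a lock."""
    os.makedirs(cache_dir, exist_ok=True)
    marker_path = os.path.join(cache_dir, "post_install_check")
    with filelock.FileLock(marker_path + ".lock"):
        if os.path.exists(marker_path):
            return  # another call (or process) already finished setup
        setup()
        open(marker_path, "w").close()  # record completion

cache_dir = os.path.join(tempfile.gettempdir(), "example_cache")
run_once(cache_dir, lambda: print("running one-time setup"))
run_once(cache_dir, lambda: print("this will not print"))
```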
@@ -86,3 +86,4 @@ def set_seed(random_seed):
     random.seed(random_seed)
     np.random.seed(random_seed)
     torch.manual_seed(random_seed)
+    torch.cuda.manual_seed(random_seed)
@@ -13,7 +13,7 @@ def batch_tokenize(tokenizer, attacked_text_list):
     return [tokenizer.encode(x) for x in inputs]


-def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
+def batch_model_predict(model, inputs, batch_size=32):
     outputs = []
     i = 0
     while i < len(inputs):
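The last hunk is cut off by the page, but the new signature makes the intent clear: walk over `inputs` in `batch_size`-sized chunks instead of reading the batch size from the removed global config. A hedged sketch of such a batching loop (a plain-Python stand-in, not the function's actual body):

```python
def batch_predict(model, inputs, batch_size=32):
    """Call `model` on successive slices of `inputs` and concatenate the results."""
    outputs = []
    i = 0
    while i < len(inputs):
        batch = inputs[i : i + batch_size]
        outputs.extend(model(batch))  # assumes one output per input
        i += batch_size
    return outputs

print(batch_predict(lambda batch: [len(x) for x in batch], ["a", "bb", "ccc"], batch_size=2))  # [1, 2, 3]
```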