1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

smart cache dir, synchronous post-install hook, better constraint caching

This commit is contained in:
Jack Morris
2020-06-23 22:35:00 -04:00
parent 3b40a3afb4
commit 2d98ce0114
13 changed files with 108 additions and 48 deletions

View File

@@ -113,7 +113,7 @@ Follow these steps to start contributing:
```bash
$ cd TextAttack
$ pip install -e .
$ pip install -e . ".[dev]"
$ pip install black isort pytest pytest-xdist
```

View File

@@ -1,11 +1,10 @@
<h1 align="center">TextAttack 🐙</h1>
<p align="center">Generating adversarial examples for NLP models</p>
<p align="center">
<a href="https://textattack.readthedocs.io/">Docs</a>
<a href="https://textattack.readthedocs.io/">Docs</a>
<br>
<a href="#about">About</a>
<a href="#setup">Setup</a>
<a href="#usage">Usage</a>
@@ -37,10 +36,9 @@ pip install textattack
Once TextAttack is installed, you can run it via command-line (`textattack ...`)
or via the python module (`python -m textattack ...`).
### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
environment variable `TA_CACHE_DIR`.
> TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
> environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).
## Usage
@@ -51,7 +49,9 @@ information about all commands using `textattack --help`, or a specific command
### Running Attacks
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack,
including building a custom transformation and a custom constraint. These examples can also be viewed
through the [documentation website](https://textattack.readthedocs.io/en/latest).
The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:

View File

@@ -1,4 +1,3 @@
click
editdistance
filelock
language_tool_python

View File

@@ -6,6 +6,11 @@ from docs import conf as docs_conf
with open("README.md", "r") as fh:
long_description = fh.read()
extras = {}
# For developers, install development tools along with all optional dependencies.
extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"]
setuptools.setup(
name="textattack",
version=docs_conf.release,
@@ -28,6 +33,7 @@ setuptools.setup(
"wandb*",
]
),
extras_require=extras,
entry_points={
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
},

View File

@@ -153,6 +153,8 @@ def parse_goal_function_from_args(args, model):
else:
raise ValueError(f"Error: unsupported goal_function {goal_function}")
goal_function.query_budget = args.query_budget
goal_function.model_batch_size = args.model_batch_size
goal_function.model_cache_size = args.model_cache_size
return goal_function
@@ -191,6 +193,9 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Invalid recipe {args.recipe}")
recipe.goal_function.query_budget = args.query_budget
recipe.goal_function.model_batch_size = args.model_batch_size
recipe.goal_function.model_cache_size = args.model_cache_size
recipe.constraint_cache_size = args.constraint_cache_size
return recipe
elif args.attack_from_file:
if ":" in args.attack_from_file:
@@ -218,7 +223,11 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Error: unsupported attack {args.search}")
return textattack.shared.Attack(
goal_function, constraints, transformation, search_method
goal_function,
constraints,
transformation,
search_method,
constraint_cache_size=args.constraint_cache_size,
)

View File

@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
parser.add_argument(
"--model-batch-size",
type=int,
default=32,
help="The batch size for making calls to the model.",
)
parser.add_argument(
"--model-cache-size",
type=int,
default=2 ** 18,
help="The maximum number of items to keep in the model results cache at once.",
)
parser.add_argument(
"--constraint-cache-size",
type=int,
default=2 ** 18,
help="The maximum number of items to keep in the constraints cache at once.",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())

View File

@@ -19,7 +19,7 @@ class UntargetedClassification(ClassificationGoalFunction):
def _is_goal_complete(self, model_output, ground_truth_output):
if self.target_max_score:
return model_output[ground_truth_output] < self.target_max_score
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
elif (model_output.numel() == 1) and isinstance(ground_truth_output, float):
return abs(ground_truth_output - model_output.item()) >= (
self.target_max_score or 0.5
)
@@ -29,7 +29,7 @@ class UntargetedClassification(ClassificationGoalFunction):
def _get_score(self, model_output, ground_truth_output):
# If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task.
if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
if (model_output.numel() == 1) and isinstance(ground_truth_output, float):
return abs(model_output.item() - ground_truth_output)
else:
return 1 - model_output[ground_truth_output]

View File

@@ -13,12 +13,21 @@ class GoalFunction:
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args:
model: The PyTorch or TensorFlow model used for evaluation.
query_budget: The maximum number of model queries allowed.
model: The model used for evaluation.
query_budget (float): The maximum number of model queries allowed.
model_batch_size (int): The batch size for making calls to the model
model_cache_size (int): The maximum number of items to keep in the model
results cache at once
"""
def __init__(
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
self,
model,
tokenizer=None,
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
model_cache_size=2 ** 18,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model.__class__
@@ -35,14 +44,15 @@ class GoalFunction:
self.use_cache = use_cache
self.num_queries = 0
self.query_budget = query_budget
self.model_batch_size = model_batch_size
if self.use_cache:
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
self._call_model_cache = lru.LRU(model_cache_size)
else:
self._call_model_cache = None
def should_skip(self, attacked_text, ground_truth_output):
"""
Returns whether or not the goal has already been completed for ``attacked_text``\,
Returns whether or not the goal has already been completed for ``attacked_text``,
due to misprediction by the model.
"""
model_outputs = self._call_model([attacked_text])
@@ -125,7 +135,9 @@ class GoalFunction:
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
with torch.no_grad():
outputs = batch_model_predict(self.model, ids)
outputs = batch_model_predict(
self.model, ids, batch_size=self.model_batch_size
)
return self._process_model_outputs(attacked_text_list, outputs)

View File

@@ -71,7 +71,7 @@ class AttackLogManager:
# Count things about attacks.
all_num_words = np.zeros(len(self.results))
perturbed_word_percentages = np.zeros(len(self.results))
num_words_changed_until_success = np.zeros(self.max_seq_len)
num_words_changed_until_success = np.zeros(2**16) # @ TODO: be smarter about this
failed_attacks = 0
skipped_attacks = 0
successful_attacks = 0

View File

@@ -27,6 +27,7 @@ class Attack:
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
transformation: The transformation applied at each step of the attack.
search_method: A strategy for exploring the search space of possible perturbations
constraint_cache_size (int): the number of items to keep in the constraints cache
"""
def __init__(
@@ -35,6 +36,7 @@ class Attack:
constraints=[],
transformation=None,
search_method=None,
constraint_cache_size=2 ** 18,
):
""" Initialize an attack object. Attacks can be run multiple times. """
self.goal_function = goal_function
@@ -68,7 +70,8 @@ class Attack:
else:
self.constraints.append(constraint)
self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
self.constraint_cache_size = constraint_cache_size
self.constraints_cache = lru.LRU(constraint_cache_size)
# Give search method access to functions for getting transformations and evaluating them
self.search_method.get_transformations = self.get_transformations
@@ -124,10 +127,10 @@ class Attack:
)
# Default to false for all original transformations.
for original_transformed_text in transformed_texts:
self.constraints_cache[original_transformed_text] = False
self.constraints_cache[(current_text, original_transformed_text)] = False
# Set unfiltered transformations to True in the cache.
for filtered_text in filtered_texts:
self.constraints_cache[filtered_text] = True
self.constraints_cache[(current_text, filtered_text)] = True
return filtered_texts
def _filter_transformations(
@@ -145,18 +148,20 @@ class Attack:
# Populate cache with transformed_texts
uncached_texts = []
for transformed_text in transformed_texts:
if transformed_text not in self.constraints_cache:
if (current_text, transformed_text) not in self.constraints_cache:
uncached_texts.append(transformed_text)
else:
# promote transformed_text to the top of the LRU cache
self.constraints_cache[transformed_text] = self.constraints_cache[
transformed_text
]
self.constraints_cache[
(current_text, transformed_text)
] = self.constraints_cache[(current_text, transformed_text)]
self._filter_transformations_uncached(
uncached_texts, current_text, original_text=original_text
)
# Return transformed_texts from cache
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
filtered_texts = [
t for t in transformed_texts if self.constraints_cache[(current_text, t)]
]
# Sort transformations to ensure order is preserved between runs
filtered_texts.sort(key=lambda t: t.text)
return filtered_texts

View File

@@ -14,12 +14,11 @@ import yaml
def path_in_cache(file_path):
textattack_cache_dir = config("CACHE_DIR")
try:
os.makedirs(textattack_cache_dir)
os.makedirs(TEXTATTACK_CACHE_DIR)
except FileExistsError: # cache path exists
pass
return os.path.join(textattack_cache_dir, file_path)
return os.path.join(TEXTATTACK_CACHE_DIR, file_path)
def s3_url(uri):
@@ -48,7 +47,7 @@ def download_if_needed(folder_name):
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
dir=config("CACHE_DIR"), suffix=".zip", delete=False
dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
)
http_get(folder_name, downloaded_file)
# Move or unzip the file.
@@ -107,7 +106,7 @@ logger.propagate = False
def _post_install():
logger.info(
"First time importing textattack: downloading remaining required packages."
"First time running textattack: downloading remaining required packages."
)
logger.info("Downloading spaCy required packages.")
import spacy
@@ -122,28 +121,39 @@ def _post_install():
nltk.download("stopwords")
def _post_install_if_needed():
def set_cache_dir(cache_dir):
""" Sets all relevant cache directories to ``TA_CACHE_DIR``. """
# Tensorflow Hub cache directory
os.environ["TFHUB_CACHE_DIR"] = cache_dir
# HuggingFace `transformers` cache directory
os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir
# HuggingFace `nlp` cache directory
os.environ["HF_HOME"] = cache_dir
# Basic directory for Linux user-specific non-data files
os.environ["XDG_CACHE_HOME"] = cache_dir
def _post_install_if_needed(cache_dir):
""" Runs _post_install if hasn't been run since install. """
# Check for post-install file.
post_install_file_path = os.path.join(config("CACHE_DIR"), "post_install_check")
post_install_file_path = os.path.join(cache_dir, "post_install_check")
post_install_file_lock_path = post_install_file_path + ".lock"
post_install_file_lock = filelock.FileLock(post_install_file_lock_path)
post_install_file_lock.acquire()
if os.path.exists(post_install_file_path):
post_install_file_lock.release()
return
# Run post-install.
_post_install()
# Create file that indicates post-install completed.
open(post_install_file_path, "w").close()
post_install_file_lock.release()
def config(key):
return config_dict[key]
TEXTATTACK_CACHE_DIR = os.environ.get(
"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
)
if "TA_CACHE_DIR" in os.environ:
set_cache_dir(os.environ["TA_CACHE_DIR"])
config_dict = {
"CACHE_DIR": os.environ.get(
"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
),
}
config_path = download_if_needed("config.yaml")
with open(config_path, "r") as f:
config_dict.update(yaml.load(f, Loader=yaml.FullLoader))
_post_install_if_needed()
_post_install_if_needed(TEXTATTACK_CACHE_DIR)

View File

@@ -86,3 +86,4 @@ def set_seed(random_seed):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

View File

@@ -13,7 +13,7 @@ def batch_tokenize(tokenizer, attacked_text_list):
return [tokenizer.encode(x) for x in inputs]
def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
def batch_model_predict(model, inputs, batch_size=32):
outputs = []
i = 0
while i < len(inputs):