mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge branch 'master' into train

Jack Morris authored on 2020-06-24 16:43:41 -04:00, committed by GitHub
17 changed files with 162 additions and 106 deletions

View File

@@ -113,7 +113,7 @@ Follow these steps to start contributing:
```bash
$ cd TextAttack
-$ pip install -e .
+$ pip install -e . ".[dev]"
$ pip install black isort pytest pytest-xdist
```

View File

@@ -1,11 +1,10 @@
<h1 align="center">TextAttack 🐙</h1>
<p align="center">Generating adversarial examples for NLP models</p>
<p align="center">
-<a href="https://textattack.readthedocs.io/">Docs</a>
+<a href="https://textattack.readthedocs.io/">Docs</a>
<br>
<a href="#about">About</a>
<a href="#setup">Setup</a>
<a href="#usage">Usage</a>
@@ -37,10 +36,9 @@ pip install textattack
Once TextAttack is installed, you can run it via command-line (`textattack ...`)
or via the python module (`python -m textattack ...`).
### Configuration
-TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
-dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
-environment variable `TA_CACHE_DIR`.
+> TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
+> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
+> environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).
## Usage
@@ -51,7 +49,9 @@ information about all commands using `textattack --help`, or a specific command
### Running Attacks
-The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
+The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack,
+including building a custom transformation and a custom constraint. These examples can also be viewed
+through the [documentation website](https://textattack.readthedocs.io/en/latest).
The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
@@ -185,9 +185,9 @@ textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 -
## Design
-### TokenizedText
+### AttackedText
-To allow for word replacement after a sequence has been tokenized, we include a `TokenizedText` object
+To allow for word replacement after a sequence has been tokenized, we include an `AttackedText` object
which maintains both a list of tokens and the original text, with punctuation. We use this object in favor of a list of words or just raw text.
### Models and Datasets
@@ -252,23 +252,23 @@ You can then run attacks on samples from this dataset by adding the argument `--
### Attacks
-The `attack_one` method in an `Attack` takes as input a `TokenizedText`, and outputs either a `SuccessfulAttackResult` if it succeeds or a `FailedAttackResult` if it fails. We formulate an attack as consisting of four components: a **goal function** which determines if the attack has succeeded, **constraints** defining which perturbations are valid, a **transformation** that generates potential modifications given an input, and a **search method** which traverses through the search space of possible perturbations.
+The `attack_one` method in an `Attack` takes as input an `AttackedText`, and outputs either a `SuccessfulAttackResult` if it succeeds or a `FailedAttackResult` if it fails. We formulate an attack as consisting of four components: a **goal function** which determines if the attack has succeeded, **constraints** defining which perturbations are valid, a **transformation** that generates potential modifications given an input, and a **search method** which traverses through the search space of possible perturbations.
### Goal Functions
-A `GoalFunction` takes as input a `TokenizedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
+A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
### Constraints
-A `Constraint` takes as input a current `TokenizedText`, and a list of transformed `TokenizedText`s. For each transformed option, it returns a boolean representing whether the constraint is met.
+A `Constraint` takes as input a current `AttackedText`, and a list of transformed `AttackedText`s. For each transformed option, it returns a boolean representing whether the constraint is met.
### Transformations
-A `Transformation` takes as input a `TokenizedText` and returns a list of possible transformed `TokenizedText`s. For example, a transformation might return all possible synonym replacements.
+A `Transformation` takes as input an `AttackedText` and returns a list of possible transformed `AttackedText`s. For example, a transformation might return all possible synonym replacements.
### Search Methods
-A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult` The search is given access to the `get_transformations` function, which takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attacks constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.
+A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult`. The search is given access to the `get_transformations` function, which takes as input an `AttackedText` object and outputs a list of possible transformations filtered by meeting all of the attack's constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.
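To make the four-component decomposition concrete, here is a minimal sketch of how the pieces fit together. These classes are simplified stand-ins, not TextAttack's actual API:

```python
# Simplified stand-ins for the four attack components described above --
# not TextAttack's real classes, just the shape of the decomposition.

class GoalFunction:
    """Decides whether a perturbed text achieves the attack goal."""
    def is_complete(self, text, ground_truth):
        raise NotImplementedError

class Constraint:
    """Decides whether a perturbation is a valid replacement."""
    def check(self, current_text, transformed_text):
        raise NotImplementedError

class Transformation:
    """Proposes candidate perturbations of an input text."""
    def __call__(self, text):
        raise NotImplementedError

def greedy_search(text, ground_truth, goal, constraints, transformation):
    """A toy search method: try each valid candidate until the goal is met."""
    for candidate in transformation(text):
        if all(c.check(text, candidate) for c in constraints):
            if goal.is_complete(candidate, ground_truth):
                return candidate  # successful attack
    return None  # failed attack
```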
## Contributing to TextAttack

View File

@@ -7,7 +7,6 @@ nlp
nltk
numpy
pandas
-pyyaml>=5.1
scikit-learn
scipy==1.4.1
sentence_transformers

View File

@@ -6,6 +6,11 @@ from docs import conf as docs_conf
with open("README.md", "r") as fh:
long_description = fh.read()
+extras = {}
+# For developers, install development tools along with all optional dependencies.
+extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"]
setuptools.setup(
name="textattack",
version=docs_conf.release,
@@ -28,6 +33,7 @@ setuptools.setup(
"wandb*",
]
),
+extras_require=extras,
entry_points={
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
},
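For readers unfamiliar with `extras_require`: it registers optional dependency groups that pip can install by name. A minimal, self-contained sketch (the package name here is a placeholder):

```python
import setuptools

# Optional dependency groups; "dev" mirrors the list added in the diff above.
extras = {"dev": ["black", "isort", "pytest", "pytest-xdist"]}

setuptools.setup(
    name="example-package",  # placeholder, not the real package name
    version="0.0.1",
    packages=setuptools.find_packages(),
    extras_require=extras,  # enables `pip install "example-package[dev]"`
)
```

With the package checked out locally, `pip install -e . ".[dev]"` installs it in editable mode together with the dev group, which is what the CONTRIBUTING.md change above relies on.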

View File

@@ -154,6 +154,8 @@ def parse_goal_function_from_args(args, model):
else:
raise ValueError(f"Error: unsupported goal_function {goal_function}")
goal_function.query_budget = args.query_budget
+goal_function.model_batch_size = args.model_batch_size
+goal_function.model_cache_size = args.model_cache_size
return goal_function
@@ -192,6 +194,9 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Invalid recipe {args.recipe}")
recipe.goal_function.query_budget = args.query_budget
+recipe.goal_function.model_batch_size = args.model_batch_size
+recipe.goal_function.model_cache_size = args.model_cache_size
+recipe.constraint_cache_size = args.constraint_cache_size
return recipe
elif args.attack_from_file:
if ":" in args.attack_from_file:
@@ -219,7 +224,11 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Error: unsupported attack {args.search}")
return textattack.shared.Attack(
-goal_function, constraints, transformation, search_method
+goal_function,
+constraints,
+transformation,
+search_method,
+constraint_cache_size=args.constraint_cache_size,
)
@@ -434,7 +443,6 @@ def parse_checkpoint_from_args(args):
checkpoint_path = args.checkpoint_file
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
-set_seed(checkpoint.args.random_seed)
return checkpoint
@@ -451,7 +459,6 @@ def merge_checkpoint_args(saved_args, cmdline_args):
""" Merge previously saved arguments for checkpoint and newly entered arguments """
args = copy.deepcopy(saved_args)
# Newly entered arguments take precedence
-args.checkpoint_resume = cmdline_args.checkpoint_resume
args.parallel = cmdline_args.parallel
# If set, replace
if cmdline_args.checkpoint_dir:
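The merge rule is worth spelling out: saved checkpoint arguments are the base, and only a few flags from the new invocation override them. A condensed sketch of that precedence (after this change, `checkpoint_resume` is set by the resume command itself rather than merged here):

```python
import copy

def merge_checkpoint_args(saved_args, cmdline_args):
    # Start from the arguments saved in the checkpoint...
    args = copy.deepcopy(saved_args)
    # ...then let selected runtime flags from the new invocation win.
    args.parallel = cmdline_args.parallel
    if cmdline_args.checkpoint_dir:  # only override when explicitly set
        args.checkpoint_dir = cmdline_args.checkpoint_dir
    return args
```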

View File

@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
+parser.add_argument(
+"--model-batch-size",
+type=int,
+default=32,
+help="The batch size for making calls to the model.",
+)
+parser.add_argument(
+"--model-cache-size",
+type=int,
+default=2 ** 18,
+help="The maximum number of items to keep in the model results cache at once.",
+)
+parser.add_argument(
+"--constraint-cache-size",
+type=int,
+default=2 ** 18,
+help="The maximum number of items to keep in the constraints cache at once.",
+)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
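The three new options are plain argparse flags; a self-contained sketch that can be run outside TextAttack:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model-batch-size", type=int, default=32,
                    help="The batch size for making calls to the model.")
parser.add_argument("--model-cache-size", type=int, default=2 ** 18,
                    help="Max items kept in the model results cache at once.")
parser.add_argument("--constraint-cache-size", type=int, default=2 ** 18,
                    help="Max items kept in the constraints cache at once.")

# argparse converts dashes to underscores in attribute names.
args = parser.parse_args(["--model-batch-size", "64"])
assert args.model_batch_size == 64
assert args.model_cache_size == 2 ** 18
```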

View File

@@ -1,6 +1,8 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+import textattack
from textattack.commands import TextAttackCommand
+from textattack.commands.attack.attack_args_helpers import *
class AttackResumeCommand(TextAttackCommand):
@@ -10,9 +12,11 @@ class AttackResumeCommand(TextAttackCommand):
A command line parser to resume a checkpointed attack from user specifications.
"""
-def run(self):
-textattack.shared.utils.set_seed(self.random_seed)
-self.checkpoint_resume = True
+def run(self, args):
+checkpoint = parse_checkpoint_from_args(args)
+args = merge_checkpoint_args(checkpoint.args, args)
+textattack.shared.utils.set_seed(args.random_seed)
+args.checkpoint_resume = True
# Run attack from checkpoint.
from textattack.commands.attack.run_attack_parallel import run as run_parallel
@@ -20,10 +24,10 @@ class AttackResumeCommand(TextAttackCommand):
run as run_single_threaded,
)
-if self.parallel:
-run_parallel(self)
+if args.parallel:
+run_parallel(args, checkpoint=checkpoint)
else:
-run_single_threaded(self)
+run_single_threaded(args, checkpoint=checkpoint)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):

View File

@@ -48,26 +48,22 @@ def attack_from_queue(args, in_queue, out_queue):
exit()
-def run(args):
+def run(args, checkpoint=None):
pytorch_multiprocessing_workaround()
-if args.checkpoint_resume:
-# Override current args with checkpoint args
-resume_checkpoint = parse_checkpoint_from_args(args)
-args = merge_checkpoint_args(resume_checkpoint.args, args)
-num_total_examples = args.num_examples
-num_remaining_attacks = resume_checkpoint.num_remaining_attacks
+num_total_examples = args.num_examples
-worklist = resume_checkpoint.worklist
-worklist_tail = resume_checkpoint.worklist_tail
+if args.checkpoint_resume:
+num_remaining_attacks = checkpoint.num_remaining_attacks
+worklist = checkpoint.worklist
+worklist_tail = checkpoint.worklist_tail
logger.info(
"Recovered from checkpoint previously saved at {}".format(
-resume_checkpoint.datetime
+checkpoint.datetime
)
)
-print(resume_checkpoint, "\n")
+print(checkpoint, "\n")
else:
-num_total_examples = args.num_examples
num_remaining_attacks = num_total_examples
worklist = deque(range(0, num_total_examples))
worklist_tail = worklist[-1]
@@ -79,7 +75,7 @@ def run(args):
start_time = time.time()
if args.checkpoint_resume:
-attack_log_manager = resume_checkpoint.log_manager
+attack_log_manager = checkpoint.log_manager
else:
attack_log_manager = parse_logger_from_args(args)
@@ -107,9 +103,9 @@ def run(args):
)
# Log results asynchronously and update progress bar.
if args.checkpoint_resume:
-num_results = resume_checkpoint.results_count
-num_failures = resume_checkpoint.num_failed_attacks
-num_successes = resume_checkpoint.num_successful_attacks
+num_results = checkpoint.results_count
+num_failures = checkpoint.num_failed_attacks
+num_successes = checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
@@ -157,10 +153,10 @@ def run(args):
args.checkpoint_interval
and len(attack_log_manager.results) % args.checkpoint_interval == 0
):
-checkpoint = textattack.shared.Checkpoint(
+new_checkpoint = textattack.shared.Checkpoint(
args, attack_log_manager, worklist, worklist_tail
)
-checkpoint.save()
+new_checkpoint.save()
attack_log_manager.flush()
pbar.close()
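The rename from `checkpoint` to `new_checkpoint` matters because `run()` now receives a `checkpoint` parameter; reusing the name would shadow the resumed-from checkpoint. A sketch of the interval logic with a dummy class:

```python
class DummyCheckpoint:
    """Stand-in for textattack.shared.Checkpoint."""
    def __init__(self, num_results):
        self.num_results = num_results

    def save(self):
        print(f"checkpoint saved after {self.num_results} results")

def maybe_save(results, checkpoint_interval):
    # A new object each time, so the resumed-from `checkpoint` is untouched.
    if checkpoint_interval and len(results) % checkpoint_interval == 0:
        new_checkpoint = DummyCheckpoint(len(results))
        new_checkpoint.save()

maybe_save(list(range(10)), 5)  # prints: checkpoint saved after 10 results
```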

View File

@@ -16,7 +16,7 @@ from .attack_args_helpers import *
logger = textattack.shared.logger
-def run(args):
+def run(args, checkpoint=None):
# Only use one GPU, if we have one.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -28,20 +28,16 @@ def run(args):
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
if args.checkpoint_resume:
-# Override current args with checkpoint args
-resume_checkpoint = parse_checkpoint_from_args(args)
-args = merge_checkpoint_args(resume_checkpoint.args, args)
-num_remaining_attacks = resume_checkpoint.num_remaining_attacks
-worklist = resume_checkpoint.worklist
-worklist_tail = resume_checkpoint.worklist_tail
+num_remaining_attacks = checkpoint.num_remaining_attacks
+worklist = checkpoint.worklist
+worklist_tail = checkpoint.worklist_tail
logger.info(
"Recovered from checkpoint previously saved at {}".format(
-resume_checkpoint.datetime
+checkpoint.datetime
)
)
-print(resume_checkpoint, "\n")
+print(checkpoint, "\n")
else:
num_remaining_attacks = args.num_examples
worklist = deque(range(0, args.num_examples))
@@ -55,7 +51,7 @@ def run(args):
# Logger
if args.checkpoint_resume:
-attack_log_manager = resume_checkpoint.log_manager
+attack_log_manager = checkpoint.log_manager
else:
attack_log_manager = parse_logger_from_args(args)
@@ -89,9 +85,9 @@ def run(args):
pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
if args.checkpoint_resume:
-num_results = resume_checkpoint.results_count
-num_failures = resume_checkpoint.num_failed_attacks
-num_successes = resume_checkpoint.num_successful_attacks
+num_results = checkpoint.results_count
+num_failures = checkpoint.num_failed_attacks
+num_successes = checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
@@ -128,10 +124,10 @@ def run(args):
args.checkpoint_interval
and len(attack_log_manager.results) % args.checkpoint_interval == 0
):
-checkpoint = textattack.shared.Checkpoint(
+new_checkpoint = textattack.shared.Checkpoint(
args, attack_log_manager, worklist, worklist_tail
)
-checkpoint.save()
+new_checkpoint.save()
attack_log_manager.flush()
pbar.close()

View File

@@ -19,7 +19,7 @@ class UntargetedClassification(ClassificationGoalFunction):
def _is_goal_complete(self, model_output, ground_truth_output):
if self.target_max_score:
return model_output[ground_truth_output] < self.target_max_score
-elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
+elif (model_output.numel() == 1) and isinstance(ground_truth_output, float):
return abs(ground_truth_output - model_output.item()) >= (
self.target_max_score or 0.5
)
@@ -29,7 +29,7 @@ class UntargetedClassification(ClassificationGoalFunction):
def _get_score(self, model_output, ground_truth_output):
# If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task.
-if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
+if (model_output.numel() == 1) and isinstance(ground_truth_output, float):
return abs(model_output.item() - ground_truth_output)
else:
return 1 - model_output[ground_truth_output]
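The `is 1` → `== 1` fix deserves a note: `is` compares object identity, not value. CPython happens to cache small integers, so `numel() is 1` can appear to work, but that is an implementation detail (and recent Python versions warn on `is` with a literal). A quick demonstration:

```python
# Value equality vs. identity. Equal large ints built at runtime are
# distinct objects, so `is` gives the wrong answer.
x = int("1073741825")
y = int("1073741825")
print(x == y)  # True  -- compares values (the correct test)
print(x is y)  # False -- compares object identity
```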

View File

@@ -13,12 +13,21 @@ class GoalFunction:
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args:
-model: The PyTorch or TensorFlow model used for evaluation.
-query_budget: The maximum number of model queries allowed.
+model: The model used for evaluation.
+query_budget (float): The maximum number of model queries allowed.
+model_batch_size (int): The batch size for making calls to the model
+model_cache_size (int): The maximum number of items to keep in the model
+results cache at once
"""
def __init__(
-self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
+self,
+model,
+tokenizer=None,
+use_cache=True,
+query_budget=float("inf"),
+model_batch_size=32,
+model_cache_size=2 ** 18,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model.__class__
@@ -35,14 +44,15 @@ class GoalFunction:
self.use_cache = use_cache
self.num_queries = 0
self.query_budget = query_budget
+self.model_batch_size = model_batch_size
if self.use_cache:
-self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
+self._call_model_cache = lru.LRU(model_cache_size)
else:
self._call_model_cache = None
def should_skip(self, attacked_text, ground_truth_output):
"""
-Returns whether or not the goal has already been completed for ``attacked_text``\,
+Returns whether or not the goal has already been completed for ``attacked_text``,
due to misprediction by the model.
"""
model_outputs = self._call_model([attacked_text])
@@ -125,7 +135,9 @@ class GoalFunction:
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
with torch.no_grad():
-outputs = batch_model_predict(self.model, ids)
+outputs = batch_model_predict(
+self.model, ids, batch_size=self.model_batch_size
+)
return self._process_model_outputs(attacked_text_list, outputs)
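`lru.LRU` comes from the `lru-dict` package: a dict with a fixed capacity that evicts the least-recently-used entry. A sketch of the cache now sized by the `model_cache_size` argument instead of the global config value:

```python
from lru import LRU  # pip install lru-dict

model_cache_size = 4  # tiny capacity so eviction is visible
call_model_cache = LRU(model_cache_size)

for i in range(6):
    call_model_cache[f"text-{i}"] = [0.1, 0.9]  # stand-in for model outputs

print(len(call_model_cache))         # 4: capacity is enforced
print("text-0" in call_model_cache)  # False: oldest entries were evicted
```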

View File

@@ -12,8 +12,6 @@ class AttackLogManager:
def __init__(self):
self.loggers = []
self.results = []
-self.max_words_changed = 0
-self.max_seq_len = 2 ** 16
def enable_stdout(self):
self.loggers.append(FileLogger(stdout=True))
@@ -71,7 +69,9 @@ class AttackLogManager:
# Count things about attacks.
all_num_words = np.zeros(len(self.results))
perturbed_word_percentages = np.zeros(len(self.results))
-num_words_changed_until_success = np.zeros(self.max_seq_len)
+num_words_changed_until_success = np.zeros(
+2 ** 16
+) # @ TODO: be smarter about this
failed_attacks = 0
skipped_attacks = 0
successful_attacks = 0
@@ -156,7 +156,7 @@ class AttackLogManager:
summary_table_rows, "Attack Results", "attack_results_summary"
)
# Show histogram of words changed.
-numbins = max(self.max_words_changed, 10)
+numbins = max(max_words_changed, 10)
for logger in self.loggers:
logger.log_hist(
num_words_changed_until_success[:numbins],
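The histogram change replaces mutable instance state (`self.max_words_changed`) with a value derived from the results at report time. A sketch of the derived-bin-count approach, with made-up sample values:

```python
import numpy as np

# Words changed per successful attack (made-up sample values).
num_words_changed = np.array([1, 3, 2, 7, 4])

max_words_changed = num_words_changed.max()
numbins = max(int(max_words_changed), 10)  # always at least 10 bins
hist, _ = np.histogram(num_words_changed, bins=numbins)
print(numbins, hist.sum())  # 10 5
```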

View File

@@ -6,7 +6,6 @@ Entailment by Jin et. al, 2019.
See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
"""
import numpy as np
from textattack.search_methods import SearchMethod
@@ -56,15 +55,19 @@ class GreedyWordSwapWIR(SearchMethod):
leave_one_texts = [
attacked_text.delete_word_at_index(i) for i in range(len_text)
]
-leave_one_scores = self._get_index_order(initial_result, leave_one_texts)
+leave_one_scores, search_over = self._get_index_order(
+initial_result, leave_one_texts
+)
elif self.wir_method == "random":
-leave_one_scores = torch.random(len_text)
+index_order = np.arange(len_text)
+np.random.shuffle(index_order)
+search_over = False
-if self.ascending:
-index_order = (leave_one_scores).argsort()
-else:
-index_order = (-leave_one_scores).argsort()
+if self.wir_method != "random":
+if self.ascending:
+index_order = (leave_one_scores).argsort()
+else:
+index_order = (-leave_one_scores).argsort()
i = 0
results = None
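After this change, the `"random"` ranking never queries the model: it just shuffles the word indices, and the `argsort` step is skipped for that branch. A sketch of both orderings (scores are made up):

```python
import numpy as np

len_text = 6
leave_one_scores = np.array([0.1, 0.5, 0.2, 0.9, 0.3, 0.4])  # made-up scores

# Importance-based order: most important word (highest score) first.
index_order = (-leave_one_scores).argsort()
print(index_order)  # [3 1 5 4 2 0]

# "random" order: no scoring, so no model queries are spent.
index_order = np.arange(len_text)
np.random.shuffle(index_order)
search_over = False  # nothing was queried, the budget is untouched
print(index_order)
```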

View File

@@ -27,6 +27,7 @@ class Attack:
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
transformation: The transformation applied at each step of the attack.
search_method: A strategy for exploring the search space of possible perturbations
+constraint_cache_size (int): the number of items to keep in the constraints cache
"""
def __init__(
@@ -35,6 +36,7 @@ class Attack:
constraints=[],
transformation=None,
search_method=None,
+constraint_cache_size=2 ** 18,
):
""" Initialize an attack object. Attacks can be run multiple times. """
self.goal_function = goal_function
@@ -68,7 +70,8 @@ class Attack:
else:
self.constraints.append(constraint)
-self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
+self.constraint_cache_size = constraint_cache_size
+self.constraints_cache = lru.LRU(constraint_cache_size)
# Give search method access to functions for getting transformations and evaluating them
self.search_method.get_transformations = self.get_transformations
@@ -124,10 +127,10 @@ class Attack:
)
# Default to false for all original transformations.
for original_transformed_text in transformed_texts:
-self.constraints_cache[original_transformed_text] = False
+self.constraints_cache[(current_text, original_transformed_text)] = False
# Set unfiltered transformations to True in the cache.
for filtered_text in filtered_texts:
-self.constraints_cache[filtered_text] = True
+self.constraints_cache[(current_text, filtered_text)] = True
return filtered_texts
def _filter_transformations(
@@ -145,18 +148,20 @@ class Attack:
# Populate cache with transformed_texts
uncached_texts = []
for transformed_text in transformed_texts:
-if transformed_text not in self.constraints_cache:
+if (current_text, transformed_text) not in self.constraints_cache:
uncached_texts.append(transformed_text)
else:
# promote transformed_text to the top of the LRU cache
-self.constraints_cache[transformed_text] = self.constraints_cache[
-transformed_text
-]
+self.constraints_cache[
+(current_text, transformed_text)
+] = self.constraints_cache[(current_text, transformed_text)]
self._filter_transformations_uncached(
uncached_texts, current_text, original_text=original_text
)
# Return transformed_texts from cache
-filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
+filtered_texts = [
+t for t in transformed_texts if self.constraints_cache[(current_text, t)]
+]
# Sort transformations to ensure order is preserved between runs
filtered_texts.sort(key=lambda t: t.text)
return filtered_texts
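Keying the constraint cache on the `(current_text, transformed_text)` pair rather than `transformed_text` alone prevents stale hits: the same candidate can satisfy the constraints relative to one starting point but not another. A toy illustration:

```python
# Constraint results depend on *both* texts, so the pair is the cache key.
constraints_cache = {}
constraints_cache[("the cat sat", "the feline sat")] = True
constraints_cache[("the dog sat", "the feline sat")] = False

# Same transformed text, different answers -- a single-key cache would
# have returned whichever result was written last.
print(constraints_cache[("the cat sat", "the feline sat")])  # True
print(constraints_cache[("the dog sat", "the feline sat")])  # False
```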

View File

@@ -10,16 +10,14 @@ import filelock
import requests
import torch
import tqdm
-import yaml
def path_in_cache(file_path):
-textattack_cache_dir = config("CACHE_DIR")
try:
-os.makedirs(textattack_cache_dir)
+os.makedirs(TEXTATTACK_CACHE_DIR)
except FileExistsError: # cache path exists
pass
-return os.path.join(textattack_cache_dir, file_path)
+return os.path.join(TEXTATTACK_CACHE_DIR, file_path)
def s3_url(uri):
@@ -48,7 +46,7 @@ def download_if_needed(folder_name):
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
-dir=config("CACHE_DIR"), suffix=".zip", delete=False
+dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False
)
http_get(folder_name, downloaded_file)
# Move or unzip the file.
@@ -107,7 +105,7 @@ logger.propagate = False
def _post_install():
logger.info(
"First time importing textattack: downloading remaining required packages."
"First time running textattack: downloading remaining required packages."
)
logger.info("Downloading spaCy required packages.")
import spacy
@@ -122,28 +120,39 @@ def _post_install():
nltk.download("stopwords")
+def set_cache_dir(cache_dir):
+""" Sets all relevant cache directories to ``TA_CACHE_DIR``. """
+# Tensorflow Hub cache directory
+os.environ["TFHUB_CACHE_DIR"] = cache_dir
+# HuggingFace `transformers` cache directory
+os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir
+# HuggingFace `nlp` cache directory
+os.environ["HF_HOME"] = cache_dir
+# Basic directory for Linux user-specific non-data files
+os.environ["XDG_CACHE_HOME"] = cache_dir
def _post_install_if_needed():
""" Runs _post_install if hasn't been run since install. """
# Check for post-install file.
-post_install_file_path = os.path.join(config("CACHE_DIR"), "post_install_check")
+post_install_file_path = path_in_cache("post_install_check")
post_install_file_lock_path = post_install_file_path + ".lock"
post_install_file_lock = filelock.FileLock(post_install_file_lock_path)
post_install_file_lock.acquire()
if os.path.exists(post_install_file_path):
post_install_file_lock.release()
return
# Run post-install.
_post_install()
# Create file that indicates post-install completed.
open(post_install_file_path, "w").close()
post_install_file_lock.release()
-def config(key):
-return config_dict[key]
+TEXTATTACK_CACHE_DIR = os.environ.get(
+"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
+)
+if "TA_CACHE_DIR" in os.environ:
+set_cache_dir(os.environ["TA_CACHE_DIR"])
-config_dict = {
-"CACHE_DIR": os.environ.get(
-"TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
-),
-}
-config_path = download_if_needed("config.yaml")
-with open(config_path, "r") as f:
-config_dict.update(yaml.load(f, Loader=yaml.FullLoader))
_post_install_if_needed()
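The net effect of this file's changes: `config.yaml` and the `config()` lookup are gone, and a single environment variable controls every cache location, resolved once at import time. A condensed, runnable sketch of the resolution logic:

```python
import os

# One variable, read once, with a sensible default.
TEXTATTACK_CACHE_DIR = os.environ.get(
    "TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack")
)

def set_cache_dir(cache_dir):
    """Point third-party caches at the same directory."""
    os.environ["TFHUB_CACHE_DIR"] = cache_dir
    os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir
    os.environ["XDG_CACHE_HOME"] = cache_dir

if "TA_CACHE_DIR" in os.environ:
    set_cache_dir(os.environ["TA_CACHE_DIR"])

print(TEXTATTACK_CACHE_DIR)
```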

View File

@@ -86,3 +86,4 @@ def set_seed(random_seed):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
+torch.cuda.manual_seed(random_seed)
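One caveat on the added line: `torch.cuda.manual_seed` seeds only the current GPU and is a silent no-op without CUDA; seeding every device would use `torch.cuda.manual_seed_all`. A sketch of the full seeding routine (assumes `torch` is installed):

```python
import random

import numpy as np
import torch

def set_seed(random_seed):
    random.seed(random_seed)                 # Python's own RNG
    np.random.seed(random_seed)              # numpy
    torch.manual_seed(random_seed)           # torch CPU RNG
    torch.cuda.manual_seed_all(random_seed)  # every GPU; no-op without CUDA

set_seed(42)
print(torch.rand(1), np.random.rand())
```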

View File

@@ -13,7 +13,7 @@ def batch_tokenize(tokenizer, attacked_text_list):
return [tokenizer.encode(x) for x in inputs]
-def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
+def batch_model_predict(model, inputs, batch_size=32):
outputs = []
i = 0
while i < len(inputs):
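The function body shown above is truncated; here is a hypothetical completion of the batching loop under the new explicit `batch_size=32` default, not the repository's exact code:

```python
import torch

def batch_model_predict(model, inputs, batch_size=32):
    """Run `model` over `inputs` in fixed-size slices and concatenate."""
    outputs = []
    i = 0
    while i < len(inputs):
        batch = inputs[i : i + batch_size]
        with torch.no_grad():  # inference only; no gradients needed
            outputs.append(model(torch.as_tensor(batch)))
        i += batch_size
    return torch.cat(outputs, dim=0)

# Usage with a stand-in "model":
preds = batch_model_predict(lambda t: t.float() * 2, list(range(70)))
print(preds.shape)  # torch.Size([70])
```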