mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix issues to pass tests
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -45,4 +45,4 @@ checkpoints/
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
.vscode
|
||||
.vscode
|
||||
@@ -12,7 +12,6 @@ Subpackages
|
||||
.. toctree::
|
||||
:maxdepth: 6
|
||||
|
||||
textattack.commands.training
|
||||
|
||||
.. automodule:: textattack.commands.attack_command
|
||||
:members:
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
textattack.commands.training package
|
||||
========================================
|
||||
|
||||
|
||||
.. automodule:: textattack.commands.training.run_training
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
.. automodule:: textattack.commands.training.train_args_helpers
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@@ -1,11 +1,6 @@
|
||||
textattack.models.tokenizers package
|
||||
====================================
|
||||
|
||||
.. automodule:: textattack.models.tokenizers
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
.. automodule:: textattack.models.tokenizers.glove_tokenizer
|
||||
:members:
|
||||
|
||||
@@ -44,6 +44,16 @@ Complete API Reference
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.training_args
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.dataset_args
|
||||
:members:
|
||||
:undoc-members:
|
||||
@@ -53,8 +63,3 @@ Complete API Reference
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: textattack.training_args
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import textattack
|
||||
import functools
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def _attack_build_fn(model_name):
|
||||
import textattack
|
||||
import transformers
|
||||
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|
||||
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
|
||||
model, tokenizer
|
||||
)
|
||||
|
||||
attack = textattack.attack_recipes.PWWSRen2019.build(model_wrapper)
|
||||
return attack
|
||||
|
||||
model_name = "textattack/bert-base-uncased-imdb"
|
||||
dataset = textattack.datasets.HuggingFaceDataset("imdb", None, split="test")
|
||||
attack_args = textattack.AttackArgs(num_examples=50, parallel=True)
|
||||
attack_build_fn = functools.partial(_attack_build_fn, model_name)
|
||||
attacker = textattack.Attacker(attack_build_fn, dataset, attack_args)
|
||||
attacker.attack_dataset()
|
||||
@@ -17,4 +17,4 @@ tqdm>=4.27,<4.50.0
|
||||
word2number
|
||||
num2words
|
||||
more-itertools
|
||||
PySocks!=1.5.7,>=1.5.6
|
||||
PySocks!=1.5.7,>=1.5.6
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
|
||||
from helpers import run_command_and_get_result
|
||||
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
"""
|
||||
|
||||
TextAttack Command Package
|
||||
===========================
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from .textattack_command import TextAttackCommand
|
||||
from . import textattack_cli
|
||||
|
||||
@@ -78,10 +78,7 @@ class WordCNNForClassification(nn.Module):
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
|
||||
torch.save(
|
||||
state_dict,
|
||||
os.path.join(output_path, "pytorch_model.bin")
|
||||
)
|
||||
torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
|
||||
with open(os.path.join(output_path, "config.json"), "w") as f:
|
||||
json.dump(self._config, f)
|
||||
|
||||
|
||||
@@ -15,7 +15,9 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
|
||||
"""Loads a HuggingFace ``transformers`` model and tokenizer."""
|
||||
|
||||
def __init__(self, model, tokenizer):
|
||||
assert isinstance(model, transformers.PreTrainedModel), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
|
||||
assert isinstance(
|
||||
model, transformers.PreTrainedModel
|
||||
), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
|
||||
assert isinstance(
|
||||
tokenizer,
|
||||
(transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
|
||||
|
||||
@@ -81,11 +81,12 @@ class Trainer:
|
||||
if not hasattr(model_wrapper, "model"):
|
||||
raise ValueError("Cannot detect `model` in `model_wrapper`")
|
||||
else:
|
||||
assert isinstance(model_wrapper.model, torch.nn.Module), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."
|
||||
assert isinstance(
|
||||
model_wrapper.model, torch.nn.Module
|
||||
), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."
|
||||
if not hasattr(model_wrapper, "tokenizer"):
|
||||
raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")
|
||||
|
||||
|
||||
self.model_wrapper = model_wrapper
|
||||
self.task_type = task_type
|
||||
self.attack = attack
|
||||
@@ -93,8 +94,6 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.training_args = training_args
|
||||
|
||||
|
||||
|
||||
def _generate_adversarial_examples(self, dataset, epoch, eval_mode=False):
|
||||
"""Generate adversarial examples using attacker."""
|
||||
if eval_mode:
|
||||
|
||||
@@ -94,16 +94,29 @@ class TrainingArgs:
|
||||
if self.lr:
|
||||
self.learning_rate = self.lr
|
||||
assert self.num_epochs > 0, "`num_epochs` must be greater than 0."
|
||||
assert self.num_clean_epochs >= 0, "`num_clean_epochs` must be greater than or equal to 0."
|
||||
assert (
|
||||
self.num_clean_epochs >= 0
|
||||
), "`num_clean_epochs` must be greater than or equal to 0."
|
||||
if self.early_stopping_epochs is not None:
|
||||
assert self.early_stopping_epochs > 0, "`early_stopping_epochs` must be greater than 0."
|
||||
assert self.attack_epoch_interval > 0, "`attack_epoch_interval` must be greater than 0."
|
||||
assert (
|
||||
self.early_stopping_epochs > 0
|
||||
), "`early_stopping_epochs` must be greater than 0."
|
||||
assert (
|
||||
self.attack_epoch_interval > 0
|
||||
), "`attack_epoch_interval` must be greater than 0."
|
||||
assert self.num_warmup_steps > 0, "`num_warmup_steps` must be greater than 0."
|
||||
assert self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1, "`num_train_adv_examples` must be greater than 0 or equal to -1."
|
||||
assert self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1, "`num_eval_adv_examples` must be greater than 0 or equal to -1."
|
||||
assert self.gradient_accumulation_steps > 0, "`gradient_accumulation_steps` must be greater than 0."
|
||||
assert self.num_clean_epochs <= self.num_epochs, f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
|
||||
|
||||
assert (
|
||||
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
|
||||
), "`num_train_adv_examples` must be greater than 0 or equal to -1."
|
||||
assert (
|
||||
self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1
|
||||
), "`num_eval_adv_examples` must be greater than 0 or equal to -1."
|
||||
assert (
|
||||
self.gradient_accumulation_steps > 0
|
||||
), "`gradient_accumulation_steps` must be greater than 0."
|
||||
assert (
|
||||
self.num_clean_epochs <= self.num_epochs
|
||||
), f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
|
||||
|
||||
@classmethod
|
||||
def add_parser_args(cls, parser):
|
||||
@@ -469,7 +482,7 @@ class _CommandLineTrainingArgs:
|
||||
|
||||
@classmethod
|
||||
def create_attack_from_args(cls, args, model_wrapper):
|
||||
import textattack
|
||||
import textattack # noqa: F401
|
||||
|
||||
assert (
|
||||
args.attack in ATTACK_RECIPE_NAMES
|
||||
|
||||
Reference in New Issue
Block a user