1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix issues to pass tests

This commit is contained in:
Jin Yong Yoo
2021-03-11 10:34:46 -05:00
parent 3854f8f770
commit 0f39de49f7
13 changed files with 49 additions and 72 deletions

2
.gitignore vendored
View File

@@ -45,4 +45,4 @@ checkpoints/
# vim
*.swp
.vscode
.vscode

View File

@@ -12,7 +12,6 @@ Subpackages
.. toctree::
:maxdepth: 6
textattack.commands.training
.. automodule:: textattack.commands.attack_command
:members:

View File

@@ -1,14 +0,0 @@
textattack.commands.training package
========================================
.. automodule:: textattack.commands.training.run_training
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.commands.training.train_args_helpers
:members:
:undoc-members:
:show-inheritance:

View File

@@ -1,11 +1,6 @@
textattack.models.tokenizers package
====================================
.. automodule:: textattack.models.tokenizers
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.models.tokenizers.glove_tokenizer
:members:

View File

@@ -44,6 +44,16 @@ Complete API Reference
:undoc-members:
:show-inheritance:
.. automodule:: textattack.training_args
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.trainer
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.dataset_args
:members:
:undoc-members:
@@ -53,8 +63,3 @@ Complete API Reference
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.training_args
:members:
:undoc-members:
:show-inheritance:

View File

@@ -1,26 +0,0 @@
import textattack
import functools
if __name__ == "__main__":
def _attack_build_fn(model_name):
import textattack
import transformers
model = transformers.AutoModelForSequenceClassification.from_pretrained(
model_name
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
model, tokenizer
)
attack = textattack.attack_recipes.PWWSRen2019.build(model_wrapper)
return attack
model_name = "textattack/bert-base-uncased-imdb"
dataset = textattack.datasets.HuggingFaceDataset("imdb", None, split="test")
attack_args = textattack.AttackArgs(num_examples=50, parallel=True)
attack_build_fn = functools.partial(_attack_build_fn, model_name)
attacker = textattack.Attacker(attack_build_fn, dataset, attack_args)
attacker.attack_dataset()

View File

@@ -17,4 +17,4 @@ tqdm>=4.27,<4.50.0
word2number
num2words
more-itertools
PySocks!=1.5.7,>=1.5.6
PySocks!=1.5.7,>=1.5.6

View File

@@ -1,6 +1,5 @@
import os
import re
import math
from helpers import run_command_and_get_result

View File

@@ -1,3 +1,11 @@
"""
TextAttack Command Package
===========================
"""
from abc import ABC, abstractmethod
from .textattack_command import TextAttackCommand
from . import textattack_cli

View File

@@ -78,10 +78,7 @@ class WordCNNForClassification(nn.Module):
if not os.path.exists(output_path):
os.makedirs(output_path)
state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
torch.save(
state_dict,
os.path.join(output_path, "pytorch_model.bin")
)
torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(self._config, f)

View File

@@ -15,7 +15,9 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
"""Loads a HuggingFace ``transformers`` model and tokenizer."""
def __init__(self, model, tokenizer):
assert isinstance(model, transformers.PreTrainedModel), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
assert isinstance(
model, transformers.PreTrainedModel
), f"`model` must be of type `transformers.PreTrainedModel`, but got type {type(model)}."
assert isinstance(
tokenizer,
(transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),

View File

@@ -81,11 +81,12 @@ class Trainer:
if not hasattr(model_wrapper, "model"):
raise ValueError("Cannot detect `model` in `model_wrapper`")
else:
assert isinstance(model_wrapper.model, torch.nn.Module), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."
assert isinstance(
model_wrapper.model, torch.nn.Module
), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."
if not hasattr(model_wrapper, "tokenizer"):
raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")
self.model_wrapper = model_wrapper
self.task_type = task_type
self.attack = attack
@@ -93,8 +94,6 @@ class Trainer:
self.eval_dataset = eval_dataset
self.training_args = training_args
def _generate_adversarial_examples(self, dataset, epoch, eval_mode=False):
"""Generate adversarial examples using attacker."""
if eval_mode:

View File

@@ -94,16 +94,29 @@ class TrainingArgs:
if self.lr:
self.learning_rate = self.lr
assert self.num_epochs > 0, "`num_epochs` must be greater than 0."
assert self.num_clean_epochs >= 0, "`num_clean_epochs` must be greater than or equal to 0."
assert (
self.num_clean_epochs >= 0
), "`num_clean_epochs` must be greater than or equal to 0."
if self.early_stopping_epochs is not None:
assert self.early_stopping_epochs > 0, "`early_stopping_epochs` must be greater than 0."
assert self.attack_epoch_interval > 0, "`attack_epoch_interval` must be greater than 0."
assert (
self.early_stopping_epochs > 0
), "`early_stopping_epochs` must be greater than 0."
assert (
self.attack_epoch_interval > 0
), "`attack_epoch_interval` must be greater than 0."
assert self.num_warmup_steps > 0, "`num_warmup_steps` must be greater than 0."
assert self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1, "`num_train_adv_examples` must be greater than 0 or equal to -1."
assert self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1, "`num_eval_adv_examples` must be greater than 0 or equal to -1."
assert self.gradient_accumulation_steps > 0, "`gradient_accumulation_steps` must be greater than 0."
assert self.num_clean_epochs <= self.num_epochs, f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
assert (
self.num_train_adv_examples > 0 or self.num_train_adv_examples == -1
), "`num_train_adv_examples` must be greater than 0 or equal to -1."
assert (
self.num_eval_adv_examples > 0 or self.num_eval_adv_examples == -1
), "`num_eval_adv_examples` must be greater than 0 or equal to -1."
assert (
self.gradient_accumulation_steps > 0
), "`gradient_accumulation_steps` must be greater than 0."
assert (
self.num_clean_epochs <= self.num_epochs
), f"`num_clean_epochs` cannot be greater than `num_epochs` ({self.num_clean_epochs} > {self.num_epochs})."
@classmethod
def add_parser_args(cls, parser):
@@ -469,7 +482,7 @@ class _CommandLineTrainingArgs:
@classmethod
def create_attack_from_args(cls, args, model_wrapper):
import textattack
import textattack # noqa: F401
assert (
args.attack in ATTACK_RECIPE_NAMES