mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update args for nlp, add sst models to our model hub

Jack Morris
2020-06-18 20:49:30 -04:00
27 changed files with 79 additions and 383 deletions
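The diff below consolidates --model handling around two registries, HUGGINGFACE_DATASET_BY_MODEL and TEXTATTACK_DATASET_BY_MODEL, each mapping a model key to a checkpoint plus an nlp dataset spec, and retires the old TEXTATTACK_MODEL_CLASS_NAMES path. A minimal sketch of that lookup pattern follows; the entries and tuple layout are illustrative assumptions, not the actual hub contents.

# Illustrative sketch only -- these entries are hypothetical placeholders, not the real model hub.
HUGGINGFACE_DATASET_BY_MODEL = {
    # model key -> (HuggingFace checkpoint name, nlp dataset spec)
    "bert-base-uncased-sst2": ("textattack/bert-base-uncased-SST-2", ("glue", "sst2", "validation")),
}
TEXTATTACK_DATASET_BY_MODEL = {
    # model key -> (TextAttack-trained checkpoint path, nlp dataset spec)
    "lstm-sst2": ("models/classification/lstm/sst2", ("glue", "sst2", "validation")),
}

def resolve_model_and_dataset(model_key):
    # Mirrors the new --model shortcut: a single lookup yields both the model and its dataset.
    for registry in (HUGGINGFACE_DATASET_BY_MODEL, TEXTATTACK_DATASET_BY_MODEL):
        if model_key in registry:
            model_path, dataset_from_nlp = registry[model_key]
            return model_path, dataset_from_nlp
    raise ValueError(f"Error: unsupported model {model_key}")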


@@ -12,7 +12,7 @@ import torch
 import textattack
 
-from .attack_args import *
+from attack_args import *
 
 
 def set_seed(random_seed):
@@ -44,15 +44,14 @@ def get_args():
     model_group = parser.add_mutually_exclusive_group()
 
     model_names = (
-        list(TEXTATTACK_MODEL_CLASS_NAMES.keys())
-        + list(HUGGINGFACE_DATASET_BY_MODEL.keys())
-        + list(TEXTATTACK_DATASET_BY_MODEL.keys())
+        list(HUGGINGFACE_DATASET_BY_MODEL.keys())
+        + list(TEXTATTACK_DATASET_BY_MODEL.keys())
     )
     model_group.add_argument(
         "--model",
         type=str,
         required=False,
-        default="bert-base-uncased-yelp",
+        default=None,
         choices=model_names,
         help="The pre-trained model to attack.",
     )
@@ -292,14 +291,12 @@ def get_args():
         raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")
 
     set_seed(args.random_seed)
 
     # Shortcuts for huggingface models using --model.
     if not args.checkpoint_resume and args.model in HUGGINGFACE_DATASET_BY_MODEL:
-        (
-            args.model_from_huggingface,
-            args.dataset_from_nlp,
-        ) = HUGGINGFACE_DATASET_BY_MODEL[args.model]
-        args.model = None
+        _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
+    elif not args.checkpoint_resume and args.model in TEXTATTACK_DATASET_BY_MODEL:
+        _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
 
     return args
@@ -452,16 +449,18 @@ def parse_model_from_args(args):
         )
         model = model.to(textattack.shared.utils.device)
         setattr(model, "tokenizer", tokenizer)
-    elif args.model_from_huggingface:
+    elif (args.model in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
         import transformers
 
-        if ":" in args.model_from_huggingface:
-            model_class, model_name = args.model_from_huggingface.split(":")
+        model_name = args.model if (args.model in HUGGINGFACE_DATASET_BY_MODEL) else args.model_from_huggingface
+        if ":" in model_name:
+            model_class, model_name = model_name.split(":")
             model_class = eval(f"transformers.{model_class}")
         else:
             model_class, model_name = (
                 transformers.AutoModelForSequenceClassification,
-                args.model_from_huggingface,
+                model_name,
             )
         colored_model_name = textattack.shared.utils.color_text(
             model_name, color="blue", method="ansi"
@@ -482,48 +481,11 @@ def parse_model_from_args(args):
             tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
         setattr(model, "tokenizer", tokenizer)
     else:
-        if ":" in args.model:
-            model_name, params = args.model.split(":")
-            colored_model_name = textattack.shared.utils.color_text(
-                model_name, color="blue", method="ansi"
-            )
-            textattack.shared.logger.info(
-                f"Loading pre-trained TextAttack model: {colored_model_name}"
-            )
-            if model_name not in TEXTATTACK_MODEL_CLASS_NAMES:
-                raise ValueError(f"Error: unsupported model {model_name}")
-            model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[model_name]}({params})")
-        elif args.model in TEXTATTACK_MODEL_CLASS_NAMES:
-            colored_model_name = textattack.shared.utils.color_text(
-                args.model, color="blue", method="ansi"
-            )
-            textattack.shared.logger.info(
-                f"Loading pre-trained TextAttack model: {colored_model_name}"
-            )
-            model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[args.model]}()")
-        elif args.model in TEXTATTACK_DATASET_BY_MODEL:
-            colored_model_name = textattack.shared.utils.color_text(
-                args.model, color="blue", method="ansi"
-            )
-            model_path, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
-            if args.model.startswith("lstm"):
-                textattack.shared.logger.info(
-                    f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
-                )
-                model = textattack.models.helpers.LSTMForClassification(
-                    model_path=model_path
-                )
-            elif args.model.startswith("cnn"):
-                textattack.shared.logger.info(
-                    f"Loading pre-trained TextAttack CNN: {colored_model_name}"
-                )
-                model = textattack.models.helpers.WordCNNForClassification(
-                    model_path=model_path
-                )
-            else:
-                raise ValueError(f"Unknown TextAttack pretrained model {args.model}")
+        if args.model in TEXTATTACK_DATASET_BY_MODEL:
+            model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
+            model = textattack.shared.utils.load_textattack_model_from_path(
+                args.model, model_path
+            )
         else:
-            raise ValueError(f"Error: unsupported model {args.model}")
+            raise ValueError(f"Error: unsupported TextAttack model {args.model}")
 
     return model
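The per-architecture LSTM/CNN branches removed above are replaced by a single call to textattack.shared.utils.load_textattack_model_from_path, whose body is not part of this diff. Judging from the deleted code, it presumably dispatches on the model-name prefix, roughly as in this sketch (an assumption about its shape, not the actual implementation):

import textattack

def load_textattack_model_from_path(model_name, model_path):
    # Sketch only: reproduces the dispatch the deleted branches performed inline.
    colored_model_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )
    if model_name.startswith("lstm"):
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
        )
        model = textattack.models.helpers.LSTMForClassification(model_path=model_path)
    elif model_name.startswith("cnn"):
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {colored_model_name}"
        )
        model = textattack.models.helpers.WordCNNForClassification(model_path=model_path)
    else:
        raise ValueError(f"Unknown TextAttack pretrained model {model_name}")
    return model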
@@ -557,14 +519,7 @@ def parse_dataset_from_args(args):
             *dataset_args, shuffle=args.shuffle
         )
     else:
-        if not args.model:
-            raise ValueError("Must supply pretrained model or dataset")
-        elif args.model in DATASET_BY_MODEL:
-            dataset = DATASET_BY_MODEL[args.model](
-                offset=args.num_examples_offset, shuffle=args.shuffle
-            )
-        else:
-            raise ValueError(f"Error: unsupported model {args.model}")
+        raise ValueError("Must supply pretrained model or dataset")
 
     return dataset