update args for nlp, add sst models to our model hub
@@ -12,7 +12,7 @@ import torch
 
 import textattack
 
-from .attack_args import *
+from attack_args import *
 
 
 def set_seed(random_seed):
@@ -44,15 +44,14 @@ def get_args():
     model_group = parser.add_mutually_exclusive_group()
 
     model_names = (
-        list(TEXTATTACK_MODEL_CLASS_NAMES.keys())
-        + list(HUGGINGFACE_DATASET_BY_MODEL.keys())
-        + list(TEXTATTACK_DATASET_BY_MODEL.keys())
+        list(HUGGINGFACE_DATASET_BY_MODEL.keys())
+        + list(TEXTATTACK_DATASET_BY_MODEL.keys())
     )
     model_group.add_argument(
         "--model",
         type=str,
         required=False,
-        default="bert-base-uncased-yelp",
+        default=None,
         choices=model_names,
         help="The pre-trained model to attack.",
     )
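Note: a minimal standalone sketch of how the merged registries drive the --model choices after this hunk. The registry entries below are hypothetical placeholders, not the repository's actual keys:

import argparse

# Hypothetical entries: each key maps a short model name to a
# (model location, dataset spec) pair, mirroring the dicts used above.
HUGGINGFACE_DATASET_BY_MODEL = {
    "bert-base-uncased-sst2": ("textattack/bert-base-uncased-SST-2", ("glue", "sst2", "validation")),
}
TEXTATTACK_DATASET_BY_MODEL = {
    "lstm-sst2": ("models/classification/lstm/sst2", ("glue", "sst2", "validation")),
}

parser = argparse.ArgumentParser()
model_names = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
    TEXTATTACK_DATASET_BY_MODEL.keys()
)
parser.add_argument(
    "--model",
    type=str,
    required=False,
    default=None,  # no implicit default anymore; the caller must pick a model or dataset
    choices=model_names,
    help="The pre-trained model to attack.",
)
print(parser.parse_args(["--model", "lstm-sst2"]).model)  # -> lstm-sst2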
@@ -292,14 +291,12 @@ def get_args():
         raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")
 
     set_seed(args.random_seed)
 
 
     # Shortcuts for huggingface models using --model.
     if not args.checkpoint_resume and args.model in HUGGINGFACE_DATASET_BY_MODEL:
-        (
-            args.model_from_huggingface,
-            args.dataset_from_nlp,
-        ) = HUGGINGFACE_DATASET_BY_MODEL[args.model]
-        args.model = None
+        _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
+    elif not args.checkpoint_resume and args.model in TEXTATTACK_DATASET_BY_MODEL:
+        _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
 
     return args
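Note: a hedged illustration of the new shortcut behavior. Unlike the old code, args.model is no longer cleared; only the dataset half of the registry tuple is copied onto args, so parse_model_from_args() can still resolve the model by name. The registry entry below is made up for the example:

from types import SimpleNamespace

# Made-up registry entry in the (model location, dataset spec) shape used above.
HUGGINGFACE_DATASET_BY_MODEL = {}
TEXTATTACK_DATASET_BY_MODEL = {
    "lstm-sst2": ("models/classification/lstm/sst2", ("glue", "sst2", "validation")),
}

args = SimpleNamespace(model="lstm-sst2", checkpoint_resume=False, dataset_from_nlp=None)
if not args.checkpoint_resume and args.model in HUGGINGFACE_DATASET_BY_MODEL:
    _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif not args.checkpoint_resume and args.model in TEXTATTACK_DATASET_BY_MODEL:
    _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]

print(args.model)             # still "lstm-sst2"
print(args.dataset_from_nlp)  # ("glue", "sst2", "validation")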
@@ -452,16 +449,18 @@ def parse_model_from_args(args):
         )
         model = model.to(textattack.shared.utils.device)
         setattr(model, "tokenizer", tokenizer)
-    elif args.model_from_huggingface:
+    elif (args.model in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
         import transformers
 
+        model_name = args.model if (args.model in HUGGINGFACE_DATASET_BY_MODEL) else args.model_from_huggingface
+
-        if ":" in args.model_from_huggingface:
-            model_class, model_name = args.model_from_huggingface.split(":")
+        if ":" in model_name:
+            model_class, model_name = model_name.split(":")
             model_class = eval(f"transformers.{model_class}")
         else:
             model_class, model_name = (
                 transformers.AutoModelForSequenceClassification,
-                args.model_from_huggingface,
+                model_name,
             )
         colored_model_name = textattack.shared.utils.color_text(
             model_name, color="blue", method="ansi"
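Note: roughly, the huggingface branch now resolves a checkpoint name from either --model (registry key) or --model-from-huggingface, and keeps the optional ModelClass:checkpoint colon syntax. A standalone sketch of that resolution, using getattr where the file uses eval; the checkpoint names in the comments are only illustrative:

import transformers

def load_huggingface_model(model_name):
    """Sketch: "BertForSequenceClassification:bert-base-uncased" picks an explicit
    transformers class; otherwise AutoModelForSequenceClassification is used."""
    if ":" in model_name:
        class_name, model_name = model_name.split(":")
        model_class = getattr(transformers, class_name)
    else:
        model_class = transformers.AutoModelForSequenceClassification
    return model_class.from_pretrained(model_name)

# Example (illustrative checkpoint names):
# model = load_huggingface_model("textattack/bert-base-uncased-SST-2")
# model = load_huggingface_model("BertForSequenceClassification:bert-base-uncased")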
@@ -482,48 +481,11 @@ def parse_model_from_args(args):
         tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
         setattr(model, "tokenizer", tokenizer)
     else:
-        if ":" in args.model:
-            model_name, params = args.model.split(":")
-            colored_model_name = textattack.shared.utils.color_text(
-                model_name, color="blue", method="ansi"
-            )
-            textattack.shared.logger.info(
-                f"Loading pre-trained TextAttack model: {colored_model_name}"
-            )
-            if model_name not in TEXTATTACK_MODEL_CLASS_NAMES:
-                raise ValueError(f"Error: unsupported model {model_name}")
-            model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[model_name]}({params})")
-        elif args.model in TEXTATTACK_MODEL_CLASS_NAMES:
-            colored_model_name = textattack.shared.utils.color_text(
-                args.model, color="blue", method="ansi"
-            )
-            textattack.shared.logger.info(
-                f"Loading pre-trained TextAttack model: {colored_model_name}"
-            )
-            model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[args.model]}()")
-        elif args.model in TEXTATTACK_DATASET_BY_MODEL:
-            colored_model_name = textattack.shared.utils.color_text(
-                args.model, color="blue", method="ansi"
-            )
-            model_path, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
-            if args.model.startswith("lstm"):
-                textattack.shared.logger.info(
-                    f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
-                )
-                model = textattack.models.helpers.LSTMForClassification(
-                    model_path=model_path
-                )
-            elif args.model.startswith("cnn"):
-                textattack.shared.logger.info(
-                    f"Loading pre-trained TextAttack CNN: {colored_model_name}"
-                )
-                model = textattack.models.helpers.WordCNNForClassification(
-                    model_path=model_path
-                )
-            else:
-                raise ValueError(f"Unknown TextAttack pretrained model {args.model}")
+        if args.model in TEXTATTACK_DATASET_BY_MODEL:
+            model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
+            model = textattack.shared.utils.load_textattack_model_from_path(args.model, model_path)
         else:
-            raise ValueError(f"Error: unsupported model {args.model}")
+            raise ValueError(f"Error: unsupported TextAttack model {args.model}")
     return model
 
 
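Note: the per-architecture branches deleted above are folded into a shared helper, load_textattack_model_from_path, whose implementation is not shown in this diff. A rough sketch of what such a loader can do, reusing only the class names that appear in the removed code; the trailing error message is an assumption:

import textattack

def load_textattack_model_from_path_sketch(model_name, model_path):
    # Dispatch on the model-name prefix, as the removed branches did.
    if model_name.startswith("lstm"):
        return textattack.models.helpers.LSTMForClassification(model_path=model_path)
    elif model_name.startswith("cnn"):
        return textattack.models.helpers.WordCNNForClassification(model_path=model_path)
    else:
        raise ValueError(f"Unknown TextAttack pretrained model {model_name}")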
@@ -557,14 +519,7 @@ def parse_dataset_from_args(args):
             *dataset_args, shuffle=args.shuffle
         )
     else:
-        if not args.model:
-            raise ValueError("Must supply pretrained model or dataset")
-        elif args.model in DATASET_BY_MODEL:
-            dataset = DATASET_BY_MODEL[args.model](
-                offset=args.num_examples_offset, shuffle=args.shuffle
-            )
-        else:
-            raise ValueError(f"Error: unsupported model {args.model}")
+        raise ValueError("Must supply pretrained model or dataset")
     return dataset
 
 
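Note: with the legacy DATASET_BY_MODEL fallback gone, the dataset must now come from args.dataset_from_nlp (or an explicit dataset argument). A hedged illustration of what a dataset_from_nlp entry ultimately resolves to via HuggingFace's nlp library; the (path, subset, split) tuple layout and the ("glue", "sst2") example are assumptions, not taken from this diff:

import nlp  # HuggingFace's `nlp` library (later renamed `datasets`)

# Assumed tuple shape: (dataset path, subset/config, split).
dataset_from_nlp = ("glue", "sst2", "validation")
path, subset, split = dataset_from_nlp
dataset = nlp.load_dataset(path, subset, split=split)
print(dataset[0])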