1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update model training code

This commit is contained in:
Jack Morris
2020-07-27 10:40:19 -04:00
parent 29432d88c8
commit 618e815eea
9 changed files with 74 additions and 28 deletions

View File

@@ -303,9 +303,18 @@ def parse_model_from_args(args):
model = textattack.shared.utils.load_textattack_model_from_path(
args.model, model_path
)
model = textattack.models.wrappers.PyTorchModelWrapper(
model, model.tokenizer, batch_size=args.model_batch_size
)
# Choose the approprate model wrapper (based on whether or not this is
# a HuggingFace model).
if isinstance(
model, textattack.models.helpers.BERTForClassification
) or isinstance(model, textattack.models.helpers.T5ForTextToText):
model = textattack.models.wrappers.HuggingFaceModelWrapper(
model, model.tokenizer, batch_size=args.model_batch_size
)
else:
model = textattack.models.wrappers.PyTorchModelWrapper(
model, model.tokenizer, batch_size=args.model_batch_size
)
elif args.model and os.path.exists(args.model):
# Support loading TextAttack-trained models via just their folder path.
# If `args.model` is a path/directory, let's assume it was a model