mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update model training code
This commit is contained in:
@@ -303,9 +303,18 @@ def parse_model_from_args(args):
|
||||
model = textattack.shared.utils.load_textattack_model_from_path(
|
||||
args.model, model_path
|
||||
)
|
||||
model = textattack.models.wrappers.PyTorchModelWrapper(
|
||||
model, model.tokenizer, batch_size=args.model_batch_size
|
||||
)
|
||||
# Choose the approprate model wrapper (based on whether or not this is
|
||||
# a HuggingFace model).
|
||||
if isinstance(
|
||||
model, textattack.models.helpers.BERTForClassification
|
||||
) or isinstance(model, textattack.models.helpers.T5ForTextToText):
|
||||
model = textattack.models.wrappers.HuggingFaceModelWrapper(
|
||||
model, model.tokenizer, batch_size=args.model_batch_size
|
||||
)
|
||||
else:
|
||||
model = textattack.models.wrappers.PyTorchModelWrapper(
|
||||
model, model.tokenizer, batch_size=args.model_batch_size
|
||||
)
|
||||
elif args.model and os.path.exists(args.model):
|
||||
# Support loading TextAttack-trained models via just their folder path.
|
||||
# If `args.model` is a path/directory, let's assume it was a model
|
||||
|
||||
Reference in New Issue
Block a user