1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

adding AG News models

This commit is contained in:
Jack Morris
2020-07-07 20:45:51 -04:00
parent ad2dba067e
commit c0a3a734d8
7 changed files with 109 additions and 19 deletions

View File

@@ -314,7 +314,9 @@ def parse_model_from_args(args):
f"Tried to load model from path {args.model} - could not find train_args.json."
)
model_train_args = json.loads(open(model_args_json_path).read())
model_train_args["model"] = args.model
if model_train_args["model"] not in {"cnn", "lstm"}:
# for huggingface models, set args.model to the path of the model
model_train_args["model"] = args.model
num_labels = model_train_args["num_labels"]
from textattack.commands.train_model.train_args_helpers import (
model_from_args,