mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
adding AG News models
This commit is contained in:
@@ -314,7 +314,9 @@ def parse_model_from_args(args):
|
||||
f"Tried to load model from path {args.model} - could not find train_args.json."
|
||||
)
|
||||
model_train_args = json.loads(open(model_args_json_path).read())
|
||||
model_train_args["model"] = args.model
|
||||
if model_train_args["model"] not in {"cnn", "lstm"}:
|
||||
# for huggingface models, set args.model to the path of the model
|
||||
model_train_args["model"] = args.model
|
||||
num_labels = model_train_args["num_labels"]
|
||||
from textattack.commands.train_model.train_args_helpers import (
|
||||
model_from_args,
|
||||
|
||||
Reference in New Issue
Block a user