Merge pull request #5 from jubs12/master

Support community models
Julian Salazar
2020-11-09 10:35:49 -08:00
committed by GitHub
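
Context for the change: community models on the Hugging Face Hub are namespaced as `org/model-name`, so the old prefix checks on the full ID never matched them. The patch derives a basename first and dispatches on that. A minimal sketch of the failing versus fixed check follows; the ID `dbmdz/bert-base-german-cased` is an illustrative community model, not one named in this PR:

    # Hypothetical community model ID of the 'org/name' form (illustrative only).
    model_fullname = 'dbmdz/bert-base-german-cased'

    # Old check: the namespace prefix defeats the architecture test.
    model_fullname.startswith('bert-')             # False

    # New check: strip the namespace, then test the basename.
    model_name = model_fullname.split('/')[-1]     # 'bert-base-german-cased'
    model_name.startswith('bert-')                 # True

    # Stock IDs contain no '/', so split('/')[-1] leaves them unchanged.
    'bert-base-en-uncased'.split('/')[-1]          # 'bert-base-en-uncased'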


@@ -89,8 +89,9 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
     if name not in SUPPORTED:
         logging.warn("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
         model_fullname = name
-        if model_fullname.startswith('albert-'):
+        model_name = model_fullname.split('/')[-1]
+        if model_name.startswith('albert-'):
             if params_file is None:
                 model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
@@ -100,7 +101,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
             vocab = None
-        elif model_fullname.startswith('bert-'):
+        elif model_name.startswith('bert-'):
             if params_file is None:
                 model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
@@ -110,7 +111,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
             vocab = None
-        elif model_fullname.startswith('distilbert-'):
+        elif model_name.startswith('distilbert-'):
             if params_file is None:
                 model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
@@ -120,7 +121,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
             vocab = None
-        elif model_fullname.startswith('xlm-'):
+        elif model_name.startswith('xlm-'):
             model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname, output_loading_info=True)
             tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
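
Taken together, the patched branches dispatch as in the sketch below. This is a simplified rendering of the diff, not the repo's actual function: MXNet contexts, the `params_file` branches, and vocab handling are omitted, stock transformers heads stand in for the repo's *ForMaskedLMOptimized wrappers, and the helper name `load_community_model` is hypothetical:

    import transformers

    def load_community_model(model_fullname):
        # Hypothetical helper: simplified sketch of the patched dispatch in get_pretrained.
        # Checking the basename lets 'org/name' community IDs match architecture prefixes.
        model_name = model_fullname.split('/')[-1]
        if model_name.startswith('albert-'):
            model = transformers.AlbertForMaskedLM.from_pretrained(model_fullname)
            tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
        elif model_name.startswith('bert-'):
            model = transformers.BertForMaskedLM.from_pretrained(model_fullname)
            tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
        elif model_name.startswith('distilbert-'):
            model = transformers.DistilBertForMaskedLM.from_pretrained(model_fullname)
            tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
        elif model_name.startswith('xlm-'):
            model = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname)
            tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
        else:
            raise ValueError("Unrecognized architecture for '{}'".format(model_fullname))
        return model, tokenizer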