From f9a3154a13810fab62b1701ebd76cc998d688718 Mon Sep 17 00:00:00 2001
From: Juliana Resplande
Date: Sat, 7 Nov 2020 14:45:18 +0000
Subject: [PATCH] Support community models

---
 src/mlm/models/__init__.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/mlm/models/__init__.py b/src/mlm/models/__init__.py
index acdee9a..b2c25bb 100644
--- a/src/mlm/models/__init__.py
+++ b/src/mlm/models/__init__.py
@@ -89,8 +89,9 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
     if name not in SUPPORTED:
         logging.warn("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
         model_fullname = name
+        model_name = model_fullname.split('/')[-1]
 
-        if model_fullname.startswith('albert-'):
+        if model_name.startswith('albert-'):
             if params_file is None:
                 model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname,
                                                                                  output_loading_info=True)
@@ -100,7 +101,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
             vocab = None
 
-        elif model_fullname.startswith('bert-'):
+        elif model_name.startswith('bert-'):
             if params_file is None:
                 model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname,
                                                                                output_loading_info=True)
@@ -110,7 +111,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
             vocab = None
 
-        elif model_fullname.startswith('distilbert-'):
+        elif model_name.startswith('distilbert-'):
             if params_file is None:
                 model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(model_fullname,
                                                                                      output_loading_info=True)
@@ -120,7 +121,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
             vocab = None
 
-        elif model_fullname.startswith('xlm-'):
+        elif model_name.startswith('xlm-'):
             model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname,
                                                                                   output_loading_info=True)
             tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)