diff --git a/src/mlm/models/__init__.py b/src/mlm/models/__init__.py
index acdee9a..b2c25bb 100644
--- a/src/mlm/models/__init__.py
+++ b/src/mlm/models/__init__.py
@@ -89,8 +89,9 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
     if name not in SUPPORTED:
         logging.warn("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
         model_fullname = name
+        model_name = model_fullname.split('/')[-1]

-        if model_fullname.startswith('albert-'):
+        if model_name.startswith('albert-'):
             if params_file is None:
                 model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname,
                     output_loading_info=True)
@@ -100,7 +101,7 @@
             tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
             vocab = None

-        elif model_fullname.startswith('bert-'):
+        elif model_name.startswith('bert-'):
             if params_file is None:
                 model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname,
                     output_loading_info=True)
@@ -110,7 +111,7 @@
             tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
             vocab = None

-        elif model_fullname.startswith('distilbert-'):
+        elif model_name.startswith('distilbert-'):
             if params_file is None:
                 model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(model_fullname,
                     output_loading_info=True)
@@ -120,7 +121,7 @@
             tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
             vocab = None

-        elif model_fullname.startswith('xlm-'):
+        elif model_name.startswith('xlm-'):
             model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname,
                 output_loading_info=True)
             tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
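
Why the indirection: Hugging Face Hub model IDs can be namespaced ('org/model'), and a prefix check on the full ID fails for those, so a namespaced ALBERT/BERT/DistilBERT/XLM checkpoint would fall through every branch. Below is a minimal sketch of the before/after dispatch behavior; 'some-org/bert-base-custom' is a hypothetical Hub ID used only for illustration, not one referenced by the patch.

# Hypothetical namespaced Hub ID, for illustration only.
model_fullname = 'some-org/bert-base-custom'

# Old dispatch: the 'some-org/' namespace hides the architecture prefix.
assert not model_fullname.startswith('bert-')

# New dispatch: strip the namespace first, then match the architecture.
model_name = model_fullname.split('/')[-1]   # -> 'bert-base-custom'
assert model_name.startswith('bert-')

# Un-namespaced IDs are unaffected: split('/') leaves them intact.
assert 'bert-base-uncased'.split('/')[-1] == 'bert-base-uncased'

Note that every from_pretrained call still receives model_fullname, so the full (possibly namespaced) ID is what gets resolved and downloaded; only the architecture dispatch uses the stripped model_name.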