mirror of
https://github.com/awslabs/mlm-scoring.git
synced 2021-10-10 02:35:08 +03:00
@@ -89,8 +89,9 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
|
||||
if name not in SUPPORTED:
|
||||
logging.warn("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
|
||||
model_fullname = name
|
||||
model_name = model_fullname.split('/')[-1]
|
||||
|
||||
if model_fullname.startswith('albert-'):
|
||||
if model_name.startswith('albert-'):
|
||||
|
||||
if params_file is None:
|
||||
model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
|
||||
@@ -100,7 +101,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
|
||||
tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
|
||||
vocab = None
|
||||
|
||||
elif model_fullname.startswith('bert-'):
|
||||
elif model_name.startswith('bert-'):
|
||||
|
||||
if params_file is None:
|
||||
model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
|
||||
@@ -110,7 +111,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
|
||||
tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
|
||||
vocab = None
|
||||
|
||||
elif model_fullname.startswith('distilbert-'):
|
||||
elif model_name.startswith('distilbert-'):
|
||||
|
||||
if params_file is None:
|
||||
model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
|
||||
@@ -120,7 +121,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
|
||||
tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
|
||||
vocab = None
|
||||
|
||||
elif model_fullname.startswith('xlm-'):
|
||||
elif model_name.startswith('xlm-'):
|
||||
|
||||
model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname, output_loading_info=True)
|
||||
tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
|
||||
|
||||
Reference in New Issue
Block a user