Mirror of https://github.com/awslabs/mlm-scoring.git, synced 2021-10-10 02:35:08 +03:00
@@ -89,8 +89,9 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
     if name not in SUPPORTED:
         logging.warn("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
         model_fullname = name
+        model_name = model_fullname.split('/')[-1]
 
-        if model_fullname.startswith('albert-'):
+        if model_name.startswith('albert-'):
 
             if params_file is None:
                 model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
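
Note: this hunk derives model_name as the basename of the model ID, so that family checks match namespaced Hugging Face IDs as well as bare ones. A minimal sketch of the effect; the ID 'someuser/albert-base-v2' is a hypothetical example, not taken from the diff:

    model_fullname = 'someuser/albert-base-v2'       # namespaced Hugging Face ID
    model_name = model_fullname.split('/')[-1]       # -> 'albert-base-v2'

    assert model_name.startswith('albert-')          # new check: matches
    assert not model_fullname.startswith('albert-')  # old check: would not match
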
@@ -100,7 +101,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
             vocab = None
 
-        elif model_fullname.startswith('bert-'):
+        elif model_name.startswith('bert-'):
 
             if params_file is None:
                 model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
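
The bert- branch mirrors the albert- one. Passing output_loading_info=True to from_pretrained returns the model together with a dict describing how checkpoint keys lined up with the model. A hedged sketch using stock transformers rather than the repo's BertForMaskedLMOptimized, which is defined elsewhere in mlm-scoring:

    import transformers

    # Returns (model, loading_info) instead of just the model.
    model, loading_info = transformers.BertForMaskedLM.from_pretrained(
        'bert-base-uncased', output_loading_info=True)

    # loading_info reports mismatches between checkpoint and model,
    # e.g. loading_info['missing_keys'] and loading_info['unexpected_keys'].
    print(loading_info['missing_keys'])
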
@@ -110,7 +111,7 @@ def get_pretrained(ctxs: List[mx.Context], name: str = 'bert-base-en-uncased', p
             tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
             vocab = None
 
-        elif model_fullname.startswith('xlm-'):
+        elif model_name.startswith('xlm-'):
 
             model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname, output_loading_info=True)
             tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
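
Unlike the other branches, the xlm- branch loads transformers.XLMWithLMHeadModel directly, with no params_file check. Taken together, the three hunks make every family check run against the basename of the model ID. A condensed, hypothetical sketch of that dispatch, not code from the repo:

    def resolve_family(model_fullname: str) -> str:
        # Strip any 'namespace/' prefix before matching the family prefix.
        model_name = model_fullname.split('/')[-1]
        for prefix in ('albert-', 'bert-', 'distilbert-', 'xlm-'):
            if model_name.startswith(prefix):
                return prefix.rstrip('-')
        raise ValueError("Model '{}' is not a supported PyTorch model".format(model_fullname))
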