mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update args for nlp, add sst models to our model hub
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import textattack
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -46,3 +47,32 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
|
||||
table_html += "</table></div>"
|
||||
|
||||
return table_html
|
||||
|
||||
def load_textattack_model_from_path(model_name, model_path):
|
||||
colored_model_name = textattack.shared.utils.color_text(
|
||||
model_name, color="blue", method="ansi"
|
||||
)
|
||||
if model_name.startswith('lstm'):
|
||||
textattack.shared.logger.info(
|
||||
f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
|
||||
)
|
||||
model = textattack.models.helpers.LSTMForClassification(
|
||||
model_path=model_path
|
||||
)
|
||||
elif model_name.startswith('cnn'):
|
||||
textattack.shared.logger.info(
|
||||
f"Loading pre-trained TextAttack CNN: {colored_model_name}"
|
||||
)
|
||||
model = textattack.models.helpers.WordCNNForClassification(
|
||||
model_path=model_path
|
||||
)
|
||||
elif model_name.startswith('bert'):
|
||||
textattack.shared.logger.info(
|
||||
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
|
||||
)
|
||||
model = textattack.models.helpers.BERTForClassification(
|
||||
model_path=model_path
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown textattack model {model_path}')
|
||||
return model
|
||||
Reference in New Issue
Block a user