1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update args for nlp, add sst models to our model hub

This commit is contained in:
Jack Morris
2020-06-18 20:49:30 -04:00
27 changed files with 79 additions and 383 deletions

View File

@@ -1,3 +1,4 @@
import textattack
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -46,3 +47,32 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
table_html += "</table></div>"
return table_html
def load_textattack_model_from_path(model_name, model_path):
colored_model_name = textattack.shared.utils.color_text(
model_name, color="blue", method="ansi"
)
if model_name.startswith('lstm'):
textattack.shared.logger.info(
f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
)
model = textattack.models.helpers.LSTMForClassification(
model_path=model_path
)
elif model_name.startswith('cnn'):
textattack.shared.logger.info(
f"Loading pre-trained TextAttack CNN: {colored_model_name}"
)
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path
)
elif model_name.startswith('bert'):
textattack.shared.logger.info(
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
)
model = textattack.models.helpers.BERTForClassification(
model_path=model_path
)
else:
raise ValueError(f'Unknown textattack model {model_path}')
return model