1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/models/helpers/utils.py

13 lines
329 B
Python

import os
import torch
from textattack.shared import utils
def load_cached_state_dict(model_folder_path):
    """Load the ``model.bin`` checkpoint found inside a model folder.

    The given path is first passed through ``utils.download_if_needed`` and
    the path it returns is used as the folder that contains ``model.bin``
    (presumably a local cache directory -- confirm against
    ``textattack.shared.utils``).

    Args:
        model_folder_path: Identifier/path of the model folder, in whatever
            form ``utils.download_if_needed`` accepts.

    Returns:
        The object returned by ``torch.load`` for ``model.bin``, with tensor
        storages mapped onto ``utils.device``.
    """
    model_folder_path = utils.download_if_needed(model_folder_path)
    model_path = os.path.join(model_folder_path, "model.bin")
    # NOTE(review): torch.load deserializes with pickle, which can execute
    # arbitrary code from the file. Only point this at trusted checkpoints.
    state_dict = torch.load(model_path, map_location=utils.device)
    return state_dict