mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
13 lines
329 B
Python
13 lines
329 B
Python
import os
|
|
|
|
import torch
|
|
|
|
from textattack.shared import utils
|
|
|
|
|
|
def load_cached_state_dict(model_folder_path):
    """Load the serialized model weights stored as ``model.bin`` in a model folder.

    ``model_folder_path`` is first passed through ``utils.download_if_needed``,
    and the folder path it returns is the one that is read from.

    NOTE(review): judging by the helper's name, this may fetch the model on
    first use — confirm against ``utils.download_if_needed``.

    Args:
        model_folder_path: Path or identifier of the folder that contains
            ``model.bin``.

    Returns:
        The object ``torch.load`` deserializes from ``<folder>/model.bin``
        (presumably a ``state_dict`` mapping, per this function's name), with
        tensors placed via ``map_location=utils.device``.
    """
    local_folder = utils.download_if_needed(model_folder_path)
    # NOTE(review): torch.load uses pickle under the hood, which can execute
    # arbitrary code while loading. Only point this at trusted checkpoints.
    return torch.load(
        os.path.join(local_folder, "model.bin"),
        map_location=utils.device,
    )
|