This commit is contained in:
bwt09
2022-06-07 01:28:21 -07:00
parent 1f2654572a
commit a5afb91861

View File

@@ -1,5 +1,7 @@
import string
import torch
from copy import deepcopy
from transformers import AutoTokenizer, AutoModelForMaskedLM
from data_utils.data_utils import stopwords, get_n_ents, get_sent, find_sublist
@@ -35,6 +37,7 @@ class LanguageModelWrapper:
def get_mask_filling_logprobs(self, prompt, ent_tuple):
assert get_n_ents(prompt) == len(ent_tuple)
ent_tuple = deepcopy(ent_tuple)
for ent_idx, ent in enumerate(ent_tuple):
if prompt.startswith(f'<ENT{ent_idx}>'):
ent_tuple[ent_idx] = ent[0].upper() + ent[1:]