mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
updated.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import string
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
||||
|
||||
from data_utils.data_utils import stopwords, get_n_ents, get_sent, find_sublist
|
||||
@@ -35,6 +37,7 @@ class LanguageModelWrapper:
|
||||
def get_mask_filling_logprobs(self, prompt, ent_tuple):
|
||||
assert get_n_ents(prompt) == len(ent_tuple)
|
||||
|
||||
ent_tuple = deepcopy(ent_tuple)
|
||||
for ent_idx, ent in enumerate(ent_tuple):
|
||||
if prompt.startswith(f'<ENT{ent_idx}>'):
|
||||
ent_tuple[ent_idx] = ent[0].upper() + ent[1:]
|
||||
|
||||
Reference in New Issue
Block a user