This commit is contained in:
bwt09
2022-06-19 22:45:53 -07:00
parent c06e6125f3
commit 823321048b
2 changed files with 14 additions and 6 deletions

View File

@@ -124,8 +124,19 @@ class KnowledgeHarvester:
return score
def score(self, prompt, ent_tuple):
logprobs = self._model.get_mask_filling_logprobs(
prompt=prompt, ent_tuple=ent_tuple)['mask_logprobs']
fill_in_result = self._model.fill_ent_tuple_in_prompt(
prompt=prompt, ent_tuple=ent_tuple)
logprobs = fill_in_result['mask_logprobs']
# encourage entities with multiple words
for mask_span in fill_in_result['mask_spans']:
ent_in_sent = self._model.tokenizer.decode(
fill_in_result['input_ids'][mask_span[0]:mask_span[1]]).strip()
n_words = len(ent_in_sent.split())
for i, pos in enumerate(fill_in_result['mask_positions']):
if mask_span[0] <= pos < mask_span[1]:
logprobs[i] /= n_words
token_wise_score = sum(logprobs) / len(logprobs)
ent_wise_score = sum(logprobs) / len(ent_tuple)

View File

@@ -34,7 +34,7 @@ class LanguageModelWrapper:
return outputs.logits[
inputs['input_ids'] == self.tokenizer.mask_token_id]
def get_mask_filling_logprobs(self, prompt, ent_tuple):
def fill_ent_tuple_in_prompt(self, prompt, ent_tuple):
assert get_n_ents(prompt) == len(ent_tuple)
ent_tuple = deepcopy(ent_tuple)
@@ -94,9 +94,6 @@ class LanguageModelWrapper:
ent_in_sent = ent_in_sent.split(punc)[0]
ent_in_sent = ent_in_sent.replace(f'<ENT{ent_idx}>', ent)
# a trick to encourage generating longer entities
ent_in_sent = ent_in_sent.split()[0]
ent_token_ids = self.tokenizer.encode(
f' {ent_in_sent}' if sent[len(prefix)] == ' ' else ent_in_sent,
add_special_tokens=False)