This commit is contained in:
bwt09
2022-06-20 09:42:04 -07:00
parent 33723eea7f
commit 4596c0b8dc
2 changed files with 9 additions and 15 deletions

View File

@@ -124,24 +124,14 @@ class KnowledgeHarvester:
return score
# NOTE(review): this method is shown inside a diff hunk ("@@ -124,24 +124,14 @@")
# whose +/- markers and indentation are missing. The duplicated `logprobs`
# assignment and the two consecutive `return` statements below are presumably
# the removed and added sides of the change rather than one runnable body --
# confirm against the repository before relying on either reading.
def score(self, prompt, ent_tuple):
# Form 1: keep the whole result of fill_ent_tuple_in_prompt; it is read below
# through its 'mask_spans', 'mask_positions' and 'mask_logprobs' keys.
fill_in_result = self._model.fill_ent_tuple_in_prompt(
prompt=prompt, ent_tuple=ent_tuple)
# Form 2: same call with the same arguments, keeping only 'mask_logprobs'.
logprobs = self._model.fill_ent_tuple_in_prompt(
prompt=prompt, ent_tuple=ent_tuple)['mask_logprobs']
# Goes with form 1: start from an empty list and add one value per mask span.
logprobs = []
# encourage entities with multiple words
# For each span, gather the log-probs whose mask position lies in the
# half-open range [mask_span[0], mask_span[1]) and keep only the smallest.
for mask_span in fill_in_result['mask_spans']:
span_logprobs = []
for i, pos in enumerate(fill_in_result['mask_positions']):
if mask_span[0] <= pos < mask_span[1]:
span_logprobs.append(fill_in_result['mask_logprobs'][i])
# min() raises ValueError if no mask position fell inside this span.
logprobs.append(min(span_logprobs))
# Arithmetic mean of the collected values (used by the first return).
mean_score = sum(logprobs) / len(logprobs)
# Same expression as mean_score under a different name (used by the second return).
token_wise_score = sum(logprobs) / len(logprobs)
# Same sum, but divided by the number of items in ent_tuple instead of len(logprobs).
ent_wise_score = sum(logprobs) / len(ent_tuple)
# Smallest value in logprobs; appears in both return expressions.
min_score = min(logprobs)
# First return: (2 * mean + min) / 3.
return (mean_score * 2. + min_score) / 3.
# Second return: unweighted average of the three scores above; presumably the
# replacement for the line before it -- TODO confirm.
return (token_wise_score + ent_wise_score + min_score) / 3.
# Attribute-style accessor: hands back the object stored in
# self._weighted_ent_tuples directly, without copying it.
@property
def weighted_ent_tuples(self):
return self._weighted_ent_tuples

View File

@@ -94,6 +94,10 @@ class LanguageModelWrapper:
ent_in_sent = ent_in_sent.split(punc)[0]
ent_in_sent = ent_in_sent.replace(f'<ENT{ent_idx}>', ent)
# only mask the first word in an entity to
# encourage entities with multiple words
ent_in_sent = ent_in_sent.split()[0]
ent_token_ids = self.tokenizer.encode(
f' {ent_in_sent}' if sent[len(prefix)] == ' ' else ent_in_sent,
add_special_tokens=False)