commit 28c809a28c
parent b1bf2e13f1
Author: bwt09
Date:   2022-06-06 23:15:09 -07:00

2 changed files with 9 additions and 2 deletions

@@ -66,12 +66,15 @@ class KnowledgeHarvester:
             neg_score = sum(neg_scores) / len(neg_scores)
             self._weighted_prompts[i][1] = \
-                (pos_score - 0.2 * neg_score) / self._prompt_temp
+                (pos_score - neg_score) / self._prompt_temp
         self._weighted_prompts = sorted(
             self._weighted_prompts,
             key=lambda t: t[1], reverse=True)[:self._max_n_prompts]
         for prompt, weight in self._weighted_prompts:
             print(f'{weight:.4f}', prompt)
+        norm_weights = softmax([weight for _, weight in self._weighted_prompts])
+        norm_weights[norm_weights < 0.02] = 0.
+        norm_weights /= norm_weights.sum()
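The three added lines turn the pruned prompt weights into a sparse distribution: softmax the weights, zero out anything below 2%, then renormalize. A minimal standalone sketch of that step, with made-up weights, assuming `softmax` is scipy.special.softmax (its import sits outside this hunk):

    # Sketch of the new normalization step with stand-in weights.
    # Assumption: `softmax` is scipy.special.softmax, which returns
    # a NumPy array, so boolean-mask assignment works below.
    from scipy.special import softmax

    weights = [3.0, 2.5, 0.1, -1.0]         # hypothetical prompt weights
    norm_weights = softmax(weights)         # probabilities summing to 1
    norm_weights[norm_weights < 0.02] = 0.  # last prompt falls below 2%
    norm_weights /= norm_weights.sum()      # renormalize to sum to 1
    print(norm_weights.round(4))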

@@ -35,6 +35,10 @@ class LanguageModelWrapper:
     def get_mask_filling_logprobs(self, prompt, ent_tuple):
         assert get_n_ents(prompt) == len(ent_tuple)
 
+        for ent_idx, ent in enumerate(ent_tuple):
+            if prompt.startswith(f'<ENT{ent_idx}>'):
+                ent_tuple[ent_idx] = ent[0].upper() + ent[1:]
+
         sent = get_sent(prompt=prompt, ent_tuple=ent_tuple)
         mask_spans = self.get_mask_spans(prompt=prompt, ent_tuple=ent_tuple)
 
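The added loop fixes surface form: when an entity slot sits at the very start of the prompt, the entity that fills it begins the sentence and should be capitalized. A small self-contained illustration with hypothetical prompt and entities (get_sent and the real prompts live elsewhere in the repo):

    # Illustration of the added rule with made-up values: the entity
    # filling the sentence-initial <ENT0> slot gets its first letter
    # uppercased so the completed sentence reads naturally.
    prompt = '<ENT0> is the capital of <ENT1> .'
    ent_tuple = ['paris', 'france']

    for ent_idx, ent in enumerate(ent_tuple):
        if prompt.startswith(f'<ENT{ent_idx}>'):
            ent_tuple[ent_idx] = ent[0].upper() + ent[1:]

    print(ent_tuple)  # ['Paris', 'france'] -- only the leading entity changes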
@@ -46,7 +50,7 @@ class LanguageModelWrapper:
             [sent] * len(mask_positions), return_tensors='pt').to('cuda')
 
         label_token_ids = []
         for i, pos in enumerate(mask_positions):
-            label_token_ids.append(masked_inputs['input_ids'][i][pos])
+            label_token_ids.append(masked_inputs['input_ids'][i][pos].item())
             masked_inputs['input_ids'][i][mask_positions[i:]] = \
                 self.tokenizer.mask_token_id
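The .item() call is the substantive fix here. Integer indexing yields a 0-dim tensor that still shares storage with the input ids, and the very next statement overwrites position pos in place with the mask token (pos is the first element of mask_positions[i:]), which would clobber the stored label. Copying the value out as a plain Python int first avoids that. A minimal demonstration of the aliasing, with made-up token ids:

    # Demonstration of the aliasing bug the .item() call avoids.
    import torch

    input_ids = torch.tensor([101, 2054, 2003, 102])  # made-up token ids
    pos, mask_token_id = 1, 103

    label_view = input_ids[pos]          # 0-dim tensor, shares storage
    label_copy = input_ids[pos].item()   # plain int, copied out

    input_ids[pos] = mask_token_id       # in-place masking, as in the loop
    print(label_view)   # tensor(103) -- the "saved" label was clobbered
    print(label_copy)   # 2054        -- preserved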