Mirror of https://github.com/tanyuqian/knowledge-harvest-from-lms.git
Commit message: updated
@@ -66,12 +66,15 @@ class KnowledgeHarvester:
             neg_score = sum(neg_scores) / len(neg_scores)
 
             self._weighted_prompts[i][1] = \
-                (pos_score - 0.2 * neg_score) / self._prompt_temp
+                (pos_score - neg_score) / self._prompt_temp
 
         self._weighted_prompts = sorted(
             self._weighted_prompts,
             key=lambda t: t[1], reverse=True)[:self._max_n_prompts]
 
         for prompt, weight in self._weighted_prompts:
             print(f'{weight:.4f}', prompt)
 
+        norm_weights = softmax([weight for _, weight in self._weighted_prompts])
+        norm_weights[norm_weights < 0.02] = 0.
+        norm_weights /= norm_weights.sum()
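This hunk drops the 0.2 factor on the negative score when re-weighting prompts and then normalizes the retained prompt weights with a softmax, zeroing any weight below 0.02 and renormalizing the rest. Below is a minimal, self-contained sketch of that normalization step, assuming scipy.special.softmax (which the class appears to use); the sample prompts and weights are invented for illustration and are not the repository's data:

from scipy.special import softmax

# Hypothetical (prompt, weight) pairs standing in for self._weighted_prompts.
weighted_prompts = [
    ('<ENT0> is the capital of <ENT1> .', 1.3),
    ('<ENT0> is a city in <ENT1> .', 0.4),
    ('<ENT0> , <ENT1> .', -2.0),
]

norm_weights = softmax([weight for _, weight in weighted_prompts])
norm_weights[norm_weights < 0.02] = 0.   # prune prompts with negligible mass
norm_weights /= norm_weights.sum()       # renormalize so the weights sum to 1
print(norm_weights)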
@@ -35,6 +35,10 @@ class LanguageModelWrapper:
     def get_mask_filling_logprobs(self, prompt, ent_tuple):
         assert get_n_ents(prompt) == len(ent_tuple)
 
+        for ent_idx, ent in enumerate(ent_tuple):
+            if prompt.startswith(f'<ENT{ent_idx}>'):
+                ent_tuple[ent_idx] = ent[0].upper() + ent[1:]
+
         sent = get_sent(prompt=prompt, ent_tuple=ent_tuple)
         mask_spans = self.get_mask_spans(prompt=prompt, ent_tuple=ent_tuple)
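The added loop upper-cases the first character of an entity whenever its slot <ENT{ent_idx}> opens the prompt, so the filled-in sentence begins with a capital letter. A small standalone illustration of that behaviour (the prompt and entities here are invented, and the repository's get_sent helper is not called):

prompt = '<ENT0> is located in <ENT1> .'
ent_tuple = ['paris', 'france']

for ent_idx, ent in enumerate(ent_tuple):
    if prompt.startswith(f'<ENT{ent_idx}>'):
        ent_tuple[ent_idx] = ent[0].upper() + ent[1:]

print(ent_tuple)  # ['Paris', 'france'] -- only the sentence-initial entity is capitalized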
@@ -46,7 +50,7 @@ class LanguageModelWrapper:
             [sent] * len(mask_positions), return_tensors='pt').to('cuda')
         label_token_ids = []
         for i, pos in enumerate(mask_positions):
-            label_token_ids.append(masked_inputs['input_ids'][i][pos])
+            label_token_ids.append(masked_inputs['input_ids'][i][pos].item())
             masked_inputs['input_ids'][i][mask_positions[i:]] = \
                 self.tokenizer.mask_token_id
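The .item() call converts the selected input id from a 0-dim tensor, which still shares storage with masked_inputs['input_ids'], into a plain Python int. Without it, the very next statement, which overwrites positions mask_positions[i:] with the mask token id, would also change the value just appended to label_token_ids. A minimal PyTorch sketch of that aliasing effect (the ids below are arbitrary, not taken from the repository):

import torch

input_ids = torch.tensor([101, 2054, 2003, 102])
pos = 2

as_tensor = input_ids[pos]      # 0-dim tensor; still a view into input_ids
as_int = input_ids[pos].item()  # plain int, detached from the tensor's storage

input_ids[pos] = 103            # simulate masking this position in place
print(as_tensor)                # tensor(103): follows the in-place edit
print(as_int)                   # 2003: keeps the original label id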