This commit is contained in:
bwt09
2022-06-06 16:11:21 -07:00
parent 3bf52451d4
commit 6f0644bae7

View File

@@ -42,13 +42,13 @@ class KnowledgeHarvester:
def update_prompts(self):
for i, (prompt, _) in enumerate(self._weighted_prompts):
scores = []
pos_scores, neg_scores = [], []
for ent_tuple in self._seed_ent_tuples:
ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]
score = self.score(prompt=prompt, ent_tuple=ent_tuple)
pos_scores.append(self.score(
prompt=prompt, ent_tuple=ent_tuple))
neg_scores = []
for ent_idx in range(len(ent_tuple)):
for ent_tuple1 in self._seed_ent_tuples:
if ent_tuple1[ent_idx] == ent_tuple[ent_idx]:
@@ -62,11 +62,11 @@ class KnowledgeHarvester:
neg_scores.append(self.score(
prompt=prompt, ent_tuple=ent_tuple_neg))
neg_score = sum(neg_scores) / len(neg_scores)
scores.append(score - neg_score)
pos_score = sum(pos_scores) / len(pos_scores)
neg_score = sum(neg_scores) / len(neg_scores)
self._weighted_prompts[i][1] = \
sum(scores) / len(scores) / self._prompt_temp
(pos_score - 0.2 * neg_score) / self._prompt_temp
self._weighted_prompts = sorted(
self._weighted_prompts,
@@ -108,8 +108,7 @@ class KnowledgeHarvester:
self._weighted_ent_tuples.append([best_ent_tuple, best_score])
self._weighted_ent_tuples = sorted(
self._weighted_ent_tuples,
key=lambda t: t[1], reverse=True)[:self._max_n_ent_tuples]
self._weighted_ent_tuples, key=lambda t: t[1], reverse=True)
norm_weights = softmax(
[weight for _, weight in self._weighted_ent_tuples])