mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
updated.
This commit is contained in:
@@ -42,13 +42,13 @@ class KnowledgeHarvester:
|
||||
|
||||
def update_prompts(self):
|
||||
for i, (prompt, _) in enumerate(self._weighted_prompts):
|
||||
scores = []
|
||||
pos_scores, neg_scores = [], []
|
||||
for ent_tuple in self._seed_ent_tuples:
|
||||
ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]
|
||||
|
||||
score = self.score(prompt=prompt, ent_tuple=ent_tuple)
|
||||
pos_scores.append(self.score(
|
||||
prompt=prompt, ent_tuple=ent_tuple))
|
||||
|
||||
neg_scores = []
|
||||
for ent_idx in range(len(ent_tuple)):
|
||||
for ent_tuple1 in self._seed_ent_tuples:
|
||||
if ent_tuple1[ent_idx] == ent_tuple[ent_idx]:
|
||||
@@ -62,11 +62,11 @@ class KnowledgeHarvester:
|
||||
neg_scores.append(self.score(
|
||||
prompt=prompt, ent_tuple=ent_tuple_neg))
|
||||
|
||||
neg_score = sum(neg_scores) / len(neg_scores)
|
||||
scores.append(score - neg_score)
|
||||
pos_score = sum(pos_scores) / len(pos_scores)
|
||||
neg_score = sum(neg_scores) / len(neg_scores)
|
||||
|
||||
self._weighted_prompts[i][1] = \
|
||||
sum(scores) / len(scores) / self._prompt_temp
|
||||
(pos_score - 0.2 * neg_score) / self._prompt_temp
|
||||
|
||||
self._weighted_prompts = sorted(
|
||||
self._weighted_prompts,
|
||||
@@ -108,8 +108,7 @@ class KnowledgeHarvester:
|
||||
self._weighted_ent_tuples.append([best_ent_tuple, best_score])
|
||||
|
||||
self._weighted_ent_tuples = sorted(
|
||||
self._weighted_ent_tuples,
|
||||
key=lambda t: t[1], reverse=True)[:self._max_n_ent_tuples]
|
||||
self._weighted_ent_tuples, key=lambda t: t[1], reverse=True)
|
||||
|
||||
norm_weights = softmax(
|
||||
[weight for _, weight in self._weighted_ent_tuples])
|
||||
|
||||
Reference in New Issue
Block a user