Mirror of https://github.com/tanyuqian/knowledge-harvest-from-lms.git (synced 2023-06-02 01:35:42 +03:00).
updated.
This commit is contained in:
@@ -54,9 +54,9 @@ class EntityTupleSearcher:
|
||||
if cur_ent_idx == n_ents:
|
||||
pred = [min(cur_logprobs), cur_ent_tuple]
|
||||
|
||||
# filter tuples with only very short entities
|
||||
if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
|
||||
return
|
||||
# # filter tuples with only very short entities
|
||||
# if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
|
||||
# return
|
||||
|
||||
for ent in cur_ent_tuple:
|
||||
for word in ent.split():
|
||||
@@ -160,7 +160,9 @@ class EntityTupleSearcher:
|
||||
if pred_ent in raw_prompt:
|
||||
return
|
||||
|
||||
heapq.heappush(collected_ent_heap, [sum(cur_logprobs), pred_ent])
|
||||
ent_logprob = (sum(cur_logprobs) + min(cur_logprobs)) / 2.
|
||||
|
||||
heapq.heappush(collected_ent_heap, [ent_logprob, pred_ent])
|
||||
while len(collected_ent_heap) > n:
|
||||
heapq.heappop(collected_ent_heap)
|
||||
|
||||
@@ -197,10 +199,10 @@ class EntityTupleSearcher:
|
||||
for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
|
||||
sum_logprob_upd = sum(cur_logprobs) + logprob.item()
|
||||
if len(collected_ent_heap) == n and \
|
||||
sum_logprob_upd < collected_ent_heap[0][0]:
|
||||
sum_logprob_upd / 2. < collected_ent_heap[0][0]:
|
||||
break
|
||||
|
||||
if sum_logprob_upd < logprob_threashold:
|
||||
if sum_logprob_upd / 2. < logprob_threashold:
|
||||
break
|
||||
|
||||
if not any([ch.isalpha() for ch in
|
||||
@@ -219,4 +221,4 @@ class EntityTupleSearcher:
|
||||
cur_logprobs=cur_logprobs + [logprob.item()],
|
||||
collected_ent_heap=collected_ent_heap,
|
||||
logprob_threashold=logprob_threashold,
|
||||
n=n)
|
||||
n=n)
|
||||
@@ -55,8 +55,13 @@ class KnowledgeHarvester:
|
||||
key=lambda t: t[1], reverse=True)[:self._max_n_prompts]
|
||||
|
||||
norm_weights = softmax([weight for _, weight in self._weighted_prompts])
|
||||
norm_weights[norm_weights < 0.02] = 0.
|
||||
norm_weights /= norm_weights.sum()
|
||||
|
||||
for i, norm_weight in enumerate(norm_weights):
|
||||
self._weighted_prompts[i][1] = norm_weight
|
||||
self._weighted_prompts = [
|
||||
t for t in self._weighted_prompts if t[1] > 1e-4]
|
||||
|
||||
def update_ent_tuples(self):
|
||||
ent_tuples = self._ent_tuple_searcher.search(
|
||||
|
||||
Reference in New Issue
Block a user