This commit is contained in:
bwt09
2022-06-06 03:25:26 -07:00
parent 37696628ac
commit 1ebaebfbd2
2 changed files with 14 additions and 7 deletions

View File

@@ -54,9 +54,9 @@ class EntityTupleSearcher:
if cur_ent_idx == n_ents:
pred = [min(cur_logprobs), cur_ent_tuple]
# filter tuples with only very short entities
if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
return
# # filter tuples with only very short entities
# if sum([len(ent) for ent in cur_ent_tuple]) == 3 * n_ents:
# return
for ent in cur_ent_tuple:
for word in ent.split():
@@ -160,7 +160,9 @@ class EntityTupleSearcher:
if pred_ent in raw_prompt:
return
heapq.heappush(collected_ent_heap, [sum(cur_logprobs), pred_ent])
ent_logprob = (sum(cur_logprobs) + min(cur_logprobs)) / 2.
heapq.heappush(collected_ent_heap, [ent_logprob, pred_ent])
while len(collected_ent_heap) > n:
heapq.heappop(collected_ent_heap)
@@ -197,10 +199,10 @@ class EntityTupleSearcher:
for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
sum_logprob_upd = sum(cur_logprobs) + logprob.item()
if len(collected_ent_heap) == n and \
sum_logprob_upd < collected_ent_heap[0][0]:
sum_logprob_upd / 2. < collected_ent_heap[0][0]:
break
if sum_logprob_upd < logprob_threashold:
if sum_logprob_upd / 2. < logprob_threashold:
break
if not any([ch.isalpha() for ch in
@@ -219,4 +221,4 @@ class EntityTupleSearcher:
cur_logprobs=cur_logprobs + [logprob.item()],
collected_ent_heap=collected_ent_heap,
logprob_threashold=logprob_threashold,
n=n)
n=n)

View File

@@ -55,8 +55,13 @@ class KnowledgeHarvester:
key=lambda t: t[1], reverse=True)[:self._max_n_prompts]
norm_weights = softmax([weight for _, weight in self._weighted_prompts])
norm_weights[norm_weights < 0.02] = 0.
norm_weights /= norm_weights.sum()
for i, norm_weight in enumerate(norm_weights):
self._weighted_prompts[i][1] = norm_weight
self._weighted_prompts = [
t for t in self._weighted_prompts if t[1] > 1e-4]
def update_ent_tuples(self):
ent_tuples = self._ent_tuple_searcher.search(