mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
updated.
This commit is contained in:
@@ -92,8 +92,14 @@ class EntityTupleSearcher:
|
||||
|
||||
collected_ents.sort(reverse=True)
|
||||
|
||||
for ent_min_logprob, pred_ent in collected_ents:
|
||||
min_upd = min(cur_logprobs + [ent_min_logprob])
|
||||
flag = set()
|
||||
for ent_logprob, pred_ent in collected_ents:
|
||||
if pred_ent in flag:
|
||||
continue
|
||||
else:
|
||||
flag.add(pred_ent)
|
||||
|
||||
min_upd = min(cur_logprobs + [ent_logprob])
|
||||
if len(collected_tuples_heap) == n and \
|
||||
min_upd < collected_tuples_heap[0][0]:
|
||||
break
|
||||
@@ -108,7 +114,7 @@ class EntityTupleSearcher:
|
||||
n_ents=n_ents,
|
||||
n_masks=n_masks,
|
||||
cur_ent_tuple=cur_ent_tuple + [pred_ent],
|
||||
cur_logprobs=cur_logprobs + [ent_min_logprob],
|
||||
cur_logprobs=cur_logprobs + [ent_logprob],
|
||||
collected_tuples_heap=collected_tuples_heap,
|
||||
repeat_cnt=repeat_cnt,
|
||||
max_word_repeat=max_word_repeat,
|
||||
@@ -146,7 +152,7 @@ class EntityTupleSearcher:
|
||||
if pred_ent.replace(' ', '') == ent.replace(' ', ''):
|
||||
return
|
||||
# filter repeating entity in the entity tuple
|
||||
if ent in pred_ent or pred_ent in ent:
|
||||
if ent.startswith(pred_ent) or pred_ent.startswith(ent):
|
||||
return
|
||||
|
||||
# filter entity appearing in the prompt
|
||||
@@ -154,7 +160,7 @@ class EntityTupleSearcher:
|
||||
if pred_ent in raw_prompt:
|
||||
return
|
||||
|
||||
heapq.heappush(collected_ent_heap, [min(cur_logprobs), pred_ent])
|
||||
heapq.heappush(collected_ent_heap, [sum(cur_logprobs), pred_ent])
|
||||
while len(collected_ent_heap) > n:
|
||||
heapq.heappop(collected_ent_heap)
|
||||
|
||||
@@ -188,13 +194,13 @@ class EntityTupleSearcher:
|
||||
logprobs = torch.log_softmax(mask_logits_total, dim=-1)
|
||||
logprobs, pred_ids = torch.sort(logprobs, descending=True)
|
||||
|
||||
for logprob, pred_id in zip(logprobs, pred_ids):
|
||||
min_upd = min(cur_logprobs + [logprob.item()])
|
||||
for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
|
||||
sum_logprob_upd = sum(cur_logprobs) + logprob.item()
|
||||
if len(collected_ent_heap) == n and \
|
||||
min_upd < collected_ent_heap[0][0]:
|
||||
sum_logprob_upd < collected_ent_heap[0][0]:
|
||||
break
|
||||
|
||||
if min_upd < logprob_threashold:
|
||||
if sum_logprob_upd < logprob_threashold:
|
||||
break
|
||||
|
||||
if not any([ch.isalpha() for ch in
|
||||
|
||||
Reference in New Issue
Block a user