This commit is contained in:
bwt09
2022-06-05 23:28:46 -07:00
parent 5a094db63c
commit 37696628ac

View File

@@ -92,8 +92,14 @@ class EntityTupleSearcher:
collected_ents.sort(reverse=True)
for ent_min_logprob, pred_ent in collected_ents:
min_upd = min(cur_logprobs + [ent_min_logprob])
flag = set()
for ent_logprob, pred_ent in collected_ents:
if pred_ent in flag:
continue
else:
flag.add(pred_ent)
min_upd = min(cur_logprobs + [ent_logprob])
if len(collected_tuples_heap) == n and \
min_upd < collected_tuples_heap[0][0]:
break
@@ -108,7 +114,7 @@ class EntityTupleSearcher:
n_ents=n_ents,
n_masks=n_masks,
cur_ent_tuple=cur_ent_tuple + [pred_ent],
cur_logprobs=cur_logprobs + [ent_min_logprob],
cur_logprobs=cur_logprobs + [ent_logprob],
collected_tuples_heap=collected_tuples_heap,
repeat_cnt=repeat_cnt,
max_word_repeat=max_word_repeat,
@@ -146,7 +152,7 @@ class EntityTupleSearcher:
if pred_ent.replace(' ', '') == ent.replace(' ', ''):
return
# filter repeating entity in the entity tuple
if ent in pred_ent or pred_ent in ent:
if ent.startswith(pred_ent) or pred_ent.startswith(ent):
return
# filter entity appearing in the prompt
@@ -154,7 +160,7 @@ class EntityTupleSearcher:
if pred_ent in raw_prompt:
return
heapq.heappush(collected_ent_heap, [min(cur_logprobs), pred_ent])
heapq.heappush(collected_ent_heap, [sum(cur_logprobs), pred_ent])
while len(collected_ent_heap) > n:
heapq.heappop(collected_ent_heap)
@@ -188,13 +194,13 @@ class EntityTupleSearcher:
logprobs = torch.log_softmax(mask_logits_total, dim=-1)
logprobs, pred_ids = torch.sort(logprobs, descending=True)
for logprob, pred_id in zip(logprobs, pred_ids):
min_upd = min(cur_logprobs + [logprob.item()])
for logprob, pred_id in zip(logprobs[:5 * n], pred_ids[:5 * n]):
sum_logprob_upd = sum(cur_logprobs) + logprob.item()
if len(collected_ent_heap) == n and \
min_upd < collected_ent_heap[0][0]:
sum_logprob_upd < collected_ent_heap[0][0]:
break
if min_upd < logprob_threashold:
if sum_logprob_upd < logprob_threashold:
break
if not any([ch.isalpha() for ch in