This commit is contained in:
bwt09
2022-06-04 21:26:01 -07:00
parent 2a8544202b
commit 3cb8de6c83
4 changed files with 25 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ stopwords.extend([
'anything', 'anybody', 'anyone',
'something', 'somebody', 'someone',
'nothing', 'nobody',
'one', 'neither', 'either',
'one', 'neither', 'either', 'many',
'us', 'first', 'second', 'next',
'following', 'last', 'new', 'main'])

View File

@@ -10,7 +10,7 @@ def main(rel_set='conceptnet',
max_n_ent_tuples=1000,
max_n_prompts=20,
prompt_temp=2.,
max_ent_repeat=5,
max_word_repeat=5,
max_ent_subwords=2,
use_init_prompts=False):
@@ -18,7 +18,7 @@ def main(rel_set='conceptnet',
model_name=model_name,
max_n_ent_tuples=max_n_ent_tuples,
max_n_prompts=max_n_prompts,
max_ent_repeat=max_ent_repeat,
max_word_repeat=max_word_repeat,
max_ent_subwords=max_ent_subwords,
prompt_temp=prompt_temp)
@@ -46,7 +46,8 @@ def main(rel_set='conceptnet',
knowledge_harvester.set_seed_ent_tuples(
seed_ent_tuples=info['seed_ent_tuples'])
knowledge_harvester.set_prompts(
prompts=info['init_prompts'] + info['prompts'])
prompts=info['init_prompts'] if use_init_prompts
else info['init_prompts'] + info['prompts'])
knowledge_harvester.update_prompts()
json.dump(knowledge_harvester.weighted_prompts, open(

View File

@@ -1,3 +1,4 @@
import string
import torch
import heapq
@@ -9,7 +10,7 @@ class EntityTupleSearcher:
def __init__(self, model):
self._model = model
def search(self, weighted_prompts, max_ent_repeat, max_ent_subwords, n):
def search(self, weighted_prompts, max_word_repeat, max_ent_subwords, n):
n_ents = get_n_ents(weighted_prompts[0][0])
collected_tuples_heap = []
@@ -28,7 +29,7 @@ class EntityTupleSearcher:
cur_logprobs=[],
collected_tuples_heap=collected_tuples_heap,
repeat_cnt=repeat_cnt,
max_ent_repeat=max_ent_repeat,
max_word_repeat=max_word_repeat,
n=n)
return [t[1] for t in collected_tuples_heap]
@@ -41,7 +42,7 @@ class EntityTupleSearcher:
cur_logprobs,
collected_tuples_heap,
repeat_cnt,
max_ent_repeat,
max_word_repeat,
n):
cur_ent_idx = len(cur_ent_tuple)
@@ -49,17 +50,20 @@ class EntityTupleSearcher:
pred = [min(cur_logprobs), cur_ent_tuple]
for ent in cur_ent_tuple:
if repeat_cnt.get(ent, 0) + 1 > max_ent_repeat:
return
for word in ent.split():
if repeat_cnt.get(word, 0) + 1 > max_word_repeat:
return
heapq.heappush(collected_tuples_heap, pred)
for ent in cur_ent_tuple:
repeat_cnt[ent] = repeat_cnt.get(ent, 0) + 1
for word in ent.split():
repeat_cnt[word] = repeat_cnt.get(word, 0) + 1
while len(collected_tuples_heap) > n:
heap_top = heapq.heappop(collected_tuples_heap)
for ent in heap_top[1]:
repeat_cnt[ent] = repeat_cnt[ent] - 1
for word in ent.split():
repeat_cnt[word] = repeat_cnt[word] - 1
return
@@ -75,7 +79,7 @@ class EntityTupleSearcher:
cur_logprobs=[],
collected_ent_heap=collected_ents,
logprob_threashold=logprob_threshold,
n=n if len(cur_ent_tuple) == 0 else max_ent_repeat)
n=n if len(cur_ent_tuple) == 0 else max_word_repeat)
collected_ents.sort(reverse=True)
@@ -98,7 +102,7 @@ class EntityTupleSearcher:
cur_logprobs=cur_logprobs + [ent_min_logprob],
collected_tuples_heap=collected_tuples_heap,
repeat_cnt=repeat_cnt,
max_ent_repeat=max_ent_repeat,
max_word_repeat=max_word_repeat,
n=n)
def dfs_ent(self,
@@ -185,6 +189,10 @@ class EntityTupleSearcher:
self._model.tokenizer.decode(pred_id)]):
continue
if any([punc in self._model.tokenizer.decode(pred_id)
for punc in string.punctuation]):
continue
self.dfs_ent(
cur_ent_tuple=cur_ent_tuple,
n_masks=n_masks,

View File

@@ -12,14 +12,14 @@ class KnowledgeHarvester:
model_name,
max_n_prompts=20,
max_n_ent_tuples=10000,
max_ent_repeat=10,
max_word_repeat=5,
max_ent_subwords=1,
prompt_temp=1.):
self._weighted_prompts = []
self._weighted_ent_tuples = []
self._max_n_prompts = max_n_prompts
self._max_n_ent_tuples = max_n_ent_tuples
self._max_ent_repeat = max_ent_repeat
self._max_word_repeat = max_word_repeat
self._max_ent_subwords = max_ent_subwords
self._prompt_temp = prompt_temp
@@ -62,7 +62,7 @@ class KnowledgeHarvester:
ent_tuples = self._ent_tuple_searcher.search(
weighted_prompts=self._weighted_prompts,
n=self._max_n_ent_tuples,
max_ent_repeat=self._max_ent_repeat,
max_word_repeat=self._max_word_repeat,
max_ent_subwords=self._max_ent_subwords)
self._weighted_ent_tuples = []