mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
updated.
This commit is contained in:
@@ -8,7 +8,7 @@ stopwords.extend([
|
||||
'anything', 'anybody', 'anyone',
|
||||
'something', 'somebody', 'someone',
|
||||
'nothing', 'nobody',
|
||||
'one', 'neither', 'either',
|
||||
'one', 'neither', 'either', 'many',
|
||||
'us', 'first', 'second', 'next',
|
||||
'following', 'last', 'new', 'main'])
|
||||
|
||||
|
||||
7
main.py
7
main.py
@@ -10,7 +10,7 @@ def main(rel_set='conceptnet',
|
||||
max_n_ent_tuples=1000,
|
||||
max_n_prompts=20,
|
||||
prompt_temp=2.,
|
||||
max_ent_repeat=5,
|
||||
max_word_repeat=5,
|
||||
max_ent_subwords=2,
|
||||
use_init_prompts=False):
|
||||
|
||||
@@ -18,7 +18,7 @@ def main(rel_set='conceptnet',
|
||||
model_name=model_name,
|
||||
max_n_ent_tuples=max_n_ent_tuples,
|
||||
max_n_prompts=max_n_prompts,
|
||||
max_ent_repeat=max_ent_repeat,
|
||||
max_word_repeat=max_word_repeat,
|
||||
max_ent_subwords=max_ent_subwords,
|
||||
prompt_temp=prompt_temp)
|
||||
|
||||
@@ -46,7 +46,8 @@ def main(rel_set='conceptnet',
|
||||
knowledge_harvester.set_seed_ent_tuples(
|
||||
seed_ent_tuples=info['seed_ent_tuples'])
|
||||
knowledge_harvester.set_prompts(
|
||||
prompts=info['init_prompts'] + info['prompts'])
|
||||
prompts=info['init_prompts'] if use_init_prompts
|
||||
else info['init_prompts'] + info['prompts'])
|
||||
|
||||
knowledge_harvester.update_prompts()
|
||||
json.dump(knowledge_harvester.weighted_prompts, open(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import string
|
||||
import torch
|
||||
import heapq
|
||||
|
||||
@@ -9,7 +10,7 @@ class EntityTupleSearcher:
|
||||
def __init__(self, model):
|
||||
self._model = model
|
||||
|
||||
def search(self, weighted_prompts, max_ent_repeat, max_ent_subwords, n):
|
||||
def search(self, weighted_prompts, max_word_repeat, max_ent_subwords, n):
|
||||
n_ents = get_n_ents(weighted_prompts[0][0])
|
||||
|
||||
collected_tuples_heap = []
|
||||
@@ -28,7 +29,7 @@ class EntityTupleSearcher:
|
||||
cur_logprobs=[],
|
||||
collected_tuples_heap=collected_tuples_heap,
|
||||
repeat_cnt=repeat_cnt,
|
||||
max_ent_repeat=max_ent_repeat,
|
||||
max_word_repeat=max_word_repeat,
|
||||
n=n)
|
||||
|
||||
return [t[1] for t in collected_tuples_heap]
|
||||
@@ -41,7 +42,7 @@ class EntityTupleSearcher:
|
||||
cur_logprobs,
|
||||
collected_tuples_heap,
|
||||
repeat_cnt,
|
||||
max_ent_repeat,
|
||||
max_word_repeat,
|
||||
n):
|
||||
cur_ent_idx = len(cur_ent_tuple)
|
||||
|
||||
@@ -49,17 +50,20 @@ class EntityTupleSearcher:
|
||||
pred = [min(cur_logprobs), cur_ent_tuple]
|
||||
|
||||
for ent in cur_ent_tuple:
|
||||
if repeat_cnt.get(ent, 0) + 1 > max_ent_repeat:
|
||||
return
|
||||
for word in ent.split():
|
||||
if repeat_cnt.get(word, 0) + 1 > max_word_repeat:
|
||||
return
|
||||
|
||||
heapq.heappush(collected_tuples_heap, pred)
|
||||
for ent in cur_ent_tuple:
|
||||
repeat_cnt[ent] = repeat_cnt.get(ent, 0) + 1
|
||||
for word in ent.split():
|
||||
repeat_cnt[word] = repeat_cnt.get(word, 0) + 1
|
||||
|
||||
while len(collected_tuples_heap) > n:
|
||||
heap_top = heapq.heappop(collected_tuples_heap)
|
||||
for ent in heap_top[1]:
|
||||
repeat_cnt[ent] = repeat_cnt[ent] - 1
|
||||
for word in ent.split():
|
||||
repeat_cnt[word] = repeat_cnt[word] - 1
|
||||
|
||||
return
|
||||
|
||||
@@ -75,7 +79,7 @@ class EntityTupleSearcher:
|
||||
cur_logprobs=[],
|
||||
collected_ent_heap=collected_ents,
|
||||
logprob_threashold=logprob_threshold,
|
||||
n=n if len(cur_ent_tuple) == 0 else max_ent_repeat)
|
||||
n=n if len(cur_ent_tuple) == 0 else max_word_repeat)
|
||||
|
||||
collected_ents.sort(reverse=True)
|
||||
|
||||
@@ -98,7 +102,7 @@ class EntityTupleSearcher:
|
||||
cur_logprobs=cur_logprobs + [ent_min_logprob],
|
||||
collected_tuples_heap=collected_tuples_heap,
|
||||
repeat_cnt=repeat_cnt,
|
||||
max_ent_repeat=max_ent_repeat,
|
||||
max_word_repeat=max_word_repeat,
|
||||
n=n)
|
||||
|
||||
def dfs_ent(self,
|
||||
@@ -185,6 +189,10 @@ class EntityTupleSearcher:
|
||||
self._model.tokenizer.decode(pred_id)]):
|
||||
continue
|
||||
|
||||
if any([punc in self._model.tokenizer.decode(pred_id)
|
||||
for punc in string.punctuation]):
|
||||
continue
|
||||
|
||||
self.dfs_ent(
|
||||
cur_ent_tuple=cur_ent_tuple,
|
||||
n_masks=n_masks,
|
||||
|
||||
@@ -12,14 +12,14 @@ class KnowledgeHarvester:
|
||||
model_name,
|
||||
max_n_prompts=20,
|
||||
max_n_ent_tuples=10000,
|
||||
max_ent_repeat=10,
|
||||
max_word_repeat=5,
|
||||
max_ent_subwords=1,
|
||||
prompt_temp=1.):
|
||||
self._weighted_prompts = []
|
||||
self._weighted_ent_tuples = []
|
||||
self._max_n_prompts = max_n_prompts
|
||||
self._max_n_ent_tuples = max_n_ent_tuples
|
||||
self._max_ent_repeat = max_ent_repeat
|
||||
self._max_word_repeat = max_word_repeat
|
||||
self._max_ent_subwords = max_ent_subwords
|
||||
self._prompt_temp = prompt_temp
|
||||
|
||||
@@ -62,7 +62,7 @@ class KnowledgeHarvester:
|
||||
ent_tuples = self._ent_tuple_searcher.search(
|
||||
weighted_prompts=self._weighted_prompts,
|
||||
n=self._max_n_ent_tuples,
|
||||
max_ent_repeat=self._max_ent_repeat,
|
||||
max_word_repeat=self._max_word_repeat,
|
||||
max_ent_subwords=self._max_ent_subwords)
|
||||
|
||||
self._weighted_ent_tuples = []
|
||||
|
||||
Reference in New Issue
Block a user