mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
updated.
This commit is contained in:
@@ -10,7 +10,7 @@ stopwords.extend([
|
||||
'nothing', 'nobody',
|
||||
'one', 'neither', 'either', 'many',
|
||||
'us', 'first', 'second', 'next',
|
||||
'following', 'last', 'new', 'main'])
|
||||
'following', 'last', 'new', 'main', 'also'])
|
||||
|
||||
|
||||
def get_n_ents(prompt):
|
||||
|
||||
@@ -32,7 +32,12 @@ class EntityTupleSearcher:
|
||||
max_word_repeat=max_word_repeat,
|
||||
n=n)
|
||||
|
||||
return [t[1] for t in collected_tuples_heap]
|
||||
ent_tuples = sorted([t[1] for t in collected_tuples_heap])
|
||||
|
||||
ent_tuples = [ent_tuples[i] for i in range(len(ent_tuples))
|
||||
if i == 0 or ent_tuples[i] != ent_tuples[i - 1]]
|
||||
|
||||
return ent_tuples
|
||||
|
||||
def dfs(self,
|
||||
weighted_prompts,
|
||||
@@ -132,10 +137,13 @@ class EntityTupleSearcher:
|
||||
if min([len(t) for t in pred_ent.split()]) <= 1:
|
||||
return
|
||||
|
||||
# filter repeating entity in the entity tuple
|
||||
for ent in cur_ent_tuple:
|
||||
# filter repeating entity in the entity tuple
|
||||
if pred_ent.replace(' ', '') == ent.replace(' ', ''):
|
||||
return
|
||||
# filter repeating entity in the entity tuple
|
||||
if ent in pred_ent or pred_ent in ent:
|
||||
return
|
||||
|
||||
# filter entity appearing in the prompt
|
||||
for raw_prompt, _ in weighted_prompts:
|
||||
|
||||
@@ -51,7 +51,7 @@ def get_paraphrase_prompt(gpt3, prompt, ent_tuple):
|
||||
return None
|
||||
|
||||
|
||||
def search_prompts(init_prompts, seed_ent_tuples):
|
||||
def search_prompts(init_prompts, seed_ent_tuples, similarity_threshold):
|
||||
gpt3 = GPT3()
|
||||
|
||||
cache = {}
|
||||
@@ -88,8 +88,9 @@ def search_prompts(init_prompts, seed_ent_tuples):
|
||||
max_sim = max([fuzz.ratio(new_prompt, prompt)
|
||||
for prompt in prompts])
|
||||
print(f'-- {new_prompt}: {max_sim}')
|
||||
if len(prompts) == 0 or max([fuzz.ratio(
|
||||
new_prompt, prompt) for prompt in prompts]) < 75:
|
||||
if len(prompts) == 0 or \
|
||||
max([fuzz.ratio(new_prompt, prompt)
|
||||
for prompt in prompts]) < similarity_threshold:
|
||||
prompts.append(new_prompt)
|
||||
flag = True
|
||||
|
||||
@@ -99,17 +100,21 @@ def search_prompts(init_prompts, seed_ent_tuples):
|
||||
if len(prompts) >= 10 or flag == False:
|
||||
break
|
||||
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = fix_prompt_style(prompts[i])
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def main(rel_set='conceptnet'):
|
||||
def main(rel_set='conceptnet', similarity_threshold=75):
|
||||
relation_info = json.load(open(f'relation_info/{rel_set}.json'))
|
||||
|
||||
for rel, info in relation_info.items():
|
||||
if 'prompts' not in info or len(info['prompts']) == 0:
|
||||
info['prompts'] = search_prompts(
|
||||
init_prompts=info['init_prompts'],
|
||||
seed_ent_tuples=info['seed_ent_tuples'])
|
||||
seed_ent_tuples=info['seed_ent_tuples'],
|
||||
similarity_threshold=similarity_threshold)
|
||||
|
||||
for key, value in info.items():
|
||||
print(f'{key}: {value}')
|
||||
@@ -117,9 +122,6 @@ def main(rel_set='conceptnet'):
|
||||
print(f'- {prompt}')
|
||||
print('=' * 50)
|
||||
|
||||
for i in range(len(info['prompts'])):
|
||||
info['prompts'][i] = fix_prompt_style(info['prompts'][i])
|
||||
|
||||
output_path = f'relation_info/{rel_set}.json'
|
||||
json.dump(relation_info, open(output_path, 'w'), indent=4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user