This commit is contained in:
bwt09
2022-06-05 00:23:30 -07:00
parent 3cb8de6c83
commit d619e0f544
3 changed files with 21 additions and 11 deletions

View File

@@ -10,7 +10,7 @@ stopwords.extend([
'nothing', 'nobody',
'one', 'neither', 'either', 'many',
'us', 'first', 'second', 'next',
'following', 'last', 'new', 'main'])
'following', 'last', 'new', 'main', 'also'])
def get_n_ents(prompt):

View File

@@ -32,7 +32,12 @@ class EntityTupleSearcher:
max_word_repeat=max_word_repeat,
n=n)
return [t[1] for t in collected_tuples_heap]
ent_tuples = sorted([t[1] for t in collected_tuples_heap])
ent_tuples = [ent_tuples[i] for i in range(len(ent_tuples))
if i == 0 or ent_tuples[i] != ent_tuples[i - 1]]
return ent_tuples
def dfs(self,
weighted_prompts,
@@ -132,10 +137,13 @@ class EntityTupleSearcher:
if min([len(t) for t in pred_ent.split()]) <= 1:
return
# filter repeating entity in the entity tuple
for ent in cur_ent_tuple:
# filter repeating entity in the entity tuple
if pred_ent.replace(' ', '') == ent.replace(' ', ''):
return
# filter repeating entity in the entity tuple
if ent in pred_ent or pred_ent in ent:
return
# filter entity appearing in the prompt
for raw_prompt, _ in weighted_prompts:

View File

@@ -51,7 +51,7 @@ def get_paraphrase_prompt(gpt3, prompt, ent_tuple):
return None
def search_prompts(init_prompts, seed_ent_tuples):
def search_prompts(init_prompts, seed_ent_tuples, similarity_threshold):
gpt3 = GPT3()
cache = {}
@@ -88,8 +88,9 @@ def search_prompts(init_prompts, seed_ent_tuples):
max_sim = max([fuzz.ratio(new_prompt, prompt)
for prompt in prompts])
print(f'-- {new_prompt}: {max_sim}')
if len(prompts) == 0 or max([fuzz.ratio(
new_prompt, prompt) for prompt in prompts]) < 75:
if len(prompts) == 0 or \
max([fuzz.ratio(new_prompt, prompt)
for prompt in prompts]) < similarity_threshold:
prompts.append(new_prompt)
flag = True
@@ -99,17 +100,21 @@ def search_prompts(init_prompts, seed_ent_tuples):
if len(prompts) >= 10 or flag == False:
break
for i in range(len(prompts)):
prompts[i] = fix_prompt_style(prompts[i])
return prompts
def main(rel_set='conceptnet'):
def main(rel_set='conceptnet', similarity_threshold=75):
relation_info = json.load(open(f'relation_info/{rel_set}.json'))
for rel, info in relation_info.items():
if 'prompts' not in info or len(info['prompts']) == 0:
info['prompts'] = search_prompts(
init_prompts=info['init_prompts'],
seed_ent_tuples=info['seed_ent_tuples'])
seed_ent_tuples=info['seed_ent_tuples'],
similarity_threshold=similarity_threshold)
for key, value in info.items():
print(f'{key}: {value}')
@@ -117,9 +122,6 @@ def main(rel_set='conceptnet'):
print(f'- {prompt}')
print('=' * 50)
for i in range(len(info['prompts'])):
info['prompts'][i] = fix_prompt_style(info['prompts'][i])
output_path = f'relation_info/{rel_set}.json'
json.dump(relation_info, open(output_path, 'w'), indent=4)