Files
knowledge-harvest-from-lms/search_prompts.py
2022-06-06 16:28:39 -07:00

134 lines
4.4 KiB
Python

import fire
import json
from nltk import sent_tokenize
from thefuzz import fuzz
from models.gpt3 import GPT3
from data_utils.data_utils import get_n_ents, get_sent, fix_prompt_style
# Each entry below is an (old, new) argument pair that get_paraphrase_prompt
# unpacks into ``str.replace``. The leading ['', ''] pair leaves the string
# unchanged, so the untouched text is always tried first.

# Rewrites applied to the paraphrased sentence before looking for an entity.
TRANSFORMATIONS_SENT = [
    ['', ''],      # sentence as-is
    ['a ', ''],    # remove every "a "
    ['the ', ''],  # remove every "the "
]

# Rewrites applied to an entity string before looking for it in the sentence.
TRANSFORMATIONS_ENT = [
    ['', ''],          # entity as-is
    ['being', 'is'],
    ['being', 'are'],
    ['ing', ''],
    ['ing', 'e'],
]
def _mask_entity(prompt, ent, idx):
    """Return ``prompt`` with the mention of ``ent`` replaced by ``<ENT{idx}>``.

    Every combination of sentence rewrite (TRANSFORMATIONS_SENT) and entity
    rewrite (TRANSFORMATIONS_ENT) is tried in order; the first combination in
    which the rewritten entity occurs exactly once in the rewritten sentence
    wins. If the placeholder is already present, or no combination matches,
    ``prompt`` is returned unchanged.
    """
    placeholder = f'<ENT{idx}>'
    if placeholder in prompt:
        return prompt

    for trans_sent in TRANSFORMATIONS_SENT:
        transed_prompt = prompt.replace(*trans_sent)
        for trans_ent in TRANSFORMATIONS_ENT:
            transed_ent = ent.replace(*trans_ent)
            # Require a unique mention so the replacement is unambiguous.
            if transed_prompt.count(transed_ent) == 1:
                return transed_prompt.replace(transed_ent, placeholder)

    return prompt


def get_paraphrase_prompt(gpt3, prompt, ent_tuple):
    """Ask the model for a paraphrase of ``prompt`` and turn it into a template.

    The sentence returned by ``get_sent(prompt, ent_tuple)`` is sent to
    ``gpt3.call`` behind a ``paraphrase:`` instruction. The first sentence of
    the completion is lower-cased, stripped of surrounding whitespace and
    periods, and each entity is then replaced by its ``<ENT{idx}>``
    placeholder. Up to five completions are requested.

    Args:
        gpt3: object exposing ``call(prompt=...)`` that returns a mapping with
            the completion text at ``['choices'][0]['text']``.
        prompt: prompt template; ``get_n_ents(prompt)`` must equal
            ``len(ent_tuple)``.
        ent_tuple: entity strings; they are lower-cased before use.

    Returns:
        The paraphrased template in which every ``<ENT{idx}>`` placeholder
        occurs exactly once, or ``None`` if none of the five attempts
        produced one.

    Raises:
        AssertionError: if the number of entities does not match the prompt.
    """
    assert get_n_ents(prompt) == len(ent_tuple)

    ent_tuple = [ent.lower() for ent in ent_tuple]
    sent = get_sent(prompt=prompt, ent_tuple=ent_tuple)

    for _ in range(5):
        raw_response = gpt3.call(prompt=f'paraphrase:\n{sent}\n')

        sentences = sent_tokenize(raw_response['choices'][0]['text'])
        if not sentences:
            # Empty / whitespace-only completion: nothing to paraphrase from,
            # so spend one of the remaining attempts instead of crashing on
            # sentences[0].
            continue

        para_sent = sentences[0].strip().strip('.').lower()
        print('para_sent:', para_sent)

        candidate = para_sent
        for idx, ent in enumerate(ent_tuple):
            candidate = _mask_entity(candidate, ent, idx)
            if candidate.count(f'<ENT{idx}>') != 1:
                # This entity could not be located unambiguously; discard the
                # paraphrase and ask for another one.
                break
        else:
            return candidate

    return None
def search_prompts(init_prompts, seed_ent_tuples, similarity_threshold):
    """Grow a set of paraphrased prompts starting from ``init_prompts``.

    Each round paraphrases every known prompt with every seed entity tuple
    (via ``get_paraphrase_prompt``), then accepts the new candidates whose
    highest ``fuzz.ratio`` against the already accepted prompts is below
    ``similarity_threshold``. Rounds stop when no candidate was produced, no
    candidate was accepted, or at least 10 prompts have been accepted.

    Args:
        init_prompts: list of starting prompt templates.
        seed_ent_tuples: iterable of entity tuples; underscores in entities
            are replaced by spaces before use.
        similarity_threshold: a candidate is accepted only if its maximum
            ``fuzz.ratio`` with the accepted prompts is strictly below this.

    Returns:
        The accepted prompts, shortest first, each passed through
        ``fix_prompt_style``.
    """
    gpt3 = GPT3()

    cache = {}
    prompts = []
    while True:
        new_prompts = []
        # init_prompts appears twice and always bypasses the cache below, so
        # every initial prompt is re-paraphrased twice per round (presumably
        # to collect more varied paraphrases -- confirm with the authors).
        for prompt in init_prompts + init_prompts + prompts:
            for ent_tuple in seed_ent_tuples:
                ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]

                request_str = f'{prompt} ||| {ent_tuple}'
                if request_str not in cache or prompt in init_prompts:
                    cache[request_str] = get_paraphrase_prompt(
                        gpt3=gpt3, prompt=prompt, ent_tuple=ent_tuple)

                para_prompt = cache[request_str]
                print(f'prompt: {prompt}\tent_tuple: {ent_tuple}'
                      f'\t-> para_prompt: {para_prompt}')

                if para_prompt is not None \
                        and para_prompt not in init_prompts \
                        and para_prompt not in prompts:
                    new_prompts.append(para_prompt)

            # Stop requesting paraphrases for this round once enough distinct
            # prompts (accepted + candidates) have been gathered.
            if len(set(prompts + new_prompts)) >= 20:
                break

        if not new_prompts:
            break

        accepted_any = False
        # Shorter candidates are considered first.
        for new_prompt in sorted(new_prompts, key=len):
            if prompts:
                max_sim = max(
                    fuzz.ratio(new_prompt, known) for known in prompts)
                print(f'-- {new_prompt}: {max_sim}')
                if max_sim >= similarity_threshold:
                    # Too close to a prompt that is already accepted.
                    continue
            prompts.append(new_prompt)
            accepted_any = True

        # De-duplicate, shortest first; the lexicographic tie-break keeps the
        # order independent of set iteration order.
        prompts = sorted(set(prompts), key=lambda s: (len(s), s))

        if len(prompts) >= 10 or not accepted_any:
            break

    return [fix_prompt_style(prompt) for prompt in prompts]
def main(rel_set='conceptnet', similarity_threshold=75):
    """Search prompts for every relation in ``relation_info/{rel_set}.json``.

    For each relation the initial prompts are normalised with
    ``fix_prompt_style``; relations that have no ``prompts`` entry yet (or an
    empty one) get one from ``search_prompts``. The JSON file is rewritten in
    place after each relation.

    Args:
        rel_set: base name of the JSON file under ``relation_info/``.
        similarity_threshold: forwarded to ``search_prompts``.
    """
    info_path = f'relation_info/{rel_set}.json'
    with open(info_path, encoding='utf-8') as info_file:
        relation_info = json.load(info_file)

    for info in relation_info.values():
        info['init_prompts'] = [
            fix_prompt_style(prompt) for prompt in info['init_prompts']]

        # Missing, empty or null 'prompts' all mean "not searched yet".
        if not info.get('prompts'):
            info['prompts'] = search_prompts(
                init_prompts=info['init_prompts'],
                seed_ent_tuples=info['seed_ent_tuples'],
                similarity_threshold=similarity_threshold)

            for key, value in info.items():
                print(f'{key}: {value}')
            for prompt in info['prompts']:
                print(f'- {prompt}')
            print('=' * 50)

        # Save after every relation so finished searches are kept if a later
        # relation fails; relations that already have prompts are skipped on
        # the next run.
        with open(info_path, 'w', encoding='utf-8') as info_file:
            json.dump(relation_info, info_file, indent=4)
if __name__ == '__main__':
    # Run only when executed as a script; python-fire maps main()'s keyword
    # arguments (rel_set, similarity_threshold) to command-line flags.
    fire.Fire(main)