Files
2022-06-06 14:54:08 -07:00

66 lines
2.1 KiB
Python

import os
import json
import fire
from models.knowledge_harvester import KnowledgeHarvester
def _write_json(obj, path, indent=None):
    """Serialize ``obj`` as JSON to ``path`` and close the file afterwards."""
    with open(path, 'w') as f:
        json.dump(obj, f, indent=indent)


def main(rel_set='conceptnet',
         model_name='roberta-large',
         max_n_ent_tuples=1000,
         max_n_prompts=20,
         prompt_temp=2.,
         max_word_repeat=5,
         max_ent_subwords=2,
         use_init_prompts=False):
    """Harvest prompts and entity tuples for every relation in a relation set.

    Relations are read from ``relation_info/{rel_set}.json``, a JSON object
    mapping each relation name to a dict with the keys ``seed_ent_tuples``,
    ``init_prompts`` and ``prompts``. For each relation the results are
    written under ``results/{rel_set}/{setting}/{model_name}/{rel}/`` as
    ``prompts.json`` and ``ent_tuples.json``. A relation whose
    ``ent_tuples.json`` already exists is skipped.

    Args:
        rel_set: Name of the relation set; selects the input JSON file and
            the first level of the output directory.
        model_name: Forwarded to ``KnowledgeHarvester``; also the last level
            of the output directory.
        max_n_ent_tuples: Forwarded to ``KnowledgeHarvester``; also embedded
            in the output ``setting`` directory name.
        max_n_prompts: Forwarded to ``KnowledgeHarvester``; embedded in the
            ``setting`` directory name when ``use_init_prompts`` is false.
        prompt_temp: Forwarded to ``KnowledgeHarvester``.
        max_word_repeat: Forwarded to ``KnowledgeHarvester``.
        max_ent_subwords: Forwarded to ``KnowledgeHarvester``.
        use_init_prompts: If true, only ``init_prompts`` are handed to the
            harvester; otherwise the de-duplicated union of ``init_prompts``
            and ``prompts`` is used.
    """
    knowledge_harvester = KnowledgeHarvester(
        model_name=model_name,
        max_n_ent_tuples=max_n_ent_tuples,
        max_n_prompts=max_n_prompts,
        max_word_repeat=max_word_repeat,
        max_ent_subwords=max_ent_subwords,
        prompt_temp=prompt_temp)

    with open(f'relation_info/{rel_set}.json') as f:
        relation_info = json.load(f)

    # The output location depends only on the CLI arguments, so compute it
    # once instead of once per relation.
    if use_init_prompts:
        setting = f'{max_n_ent_tuples}tuples_initprompts'
    else:
        setting = f'{max_n_ent_tuples}tuples_top{max_n_prompts}prompts'
    output_dir = f'results/{rel_set}/{setting}/{model_name}'

    for rel, info in relation_info.items():
        print(f'Harvesting for relation {rel}...')

        rel_dir = f'{output_dir}/{rel}'
        ent_tuples_path = f'{rel_dir}/ent_tuples.json'

        if os.path.exists(ent_tuples_path):
            print(f'file {ent_tuples_path} exists, skipped.')
            continue

        os.makedirs(rel_dir, exist_ok=True)
        # Write an empty placeholder before the harvest starts, so the
        # existence check above also skips relations that were started.
        # NOTE(review): presumably this lets parallel runs split the
        # relations between them -- confirm.
        _write_json([], ent_tuples_path)

        knowledge_harvester.clear()
        knowledge_harvester.set_seed_ent_tuples(
            seed_ent_tuples=info['seed_ent_tuples'])

        if use_init_prompts:
            prompts = info['init_prompts']
        else:
            # set() removes duplicates; the resulting order is arbitrary.
            prompts = list(set(info['init_prompts'] + info['prompts']))
        knowledge_harvester.set_prompts(prompts=prompts)

        knowledge_harvester.update_prompts()
        _write_json(
            knowledge_harvester.weighted_prompts,
            f'{rel_dir}/prompts.json',
            indent=4)

        for prompt, weight in knowledge_harvester.weighted_prompts:
            print(f'{weight:.4f} {prompt}')

        knowledge_harvester.update_ent_tuples()
        # Overwrites the placeholder with the harvested tuples.
        _write_json(
            knowledge_harvester.weighted_ent_tuples,
            ent_tuples_path,
            indent=4)


# Script entry point: python-fire exposes the keyword arguments of ``main``
# as command-line flags.
if __name__ == "__main__":
    fire.Fire(main)