mirror of
https://github.com/tanyuqian/knowledge-harvest-from-lms.git
synced 2023-06-02 01:35:42 +03:00
66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
import os
|
|
import json
|
|
import fire
|
|
|
|
from models.knowledge_harvester import KnowledgeHarvester
|
|
|
|
|
|
def main(rel_set='conceptnet',
         model_name='roberta-large',
         max_n_ent_tuples=1000,
         max_n_prompts=20,
         prompt_temp=2.,
         max_word_repeat=5,
         max_ent_subwords=2,
         use_init_prompts=False):
    """Harvest weighted prompts and entity tuples for every relation in a set.

    Reads ``relation_info/{rel_set}.json`` (a mapping from relation name to a
    dict holding ``seed_ent_tuples``, ``init_prompts`` and ``prompts``), runs
    a single ``KnowledgeHarvester`` over each relation in turn, and writes
    ``prompts.json`` and ``ent_tuples.json`` to
    ``results/{rel_set}/{setting}/{model_name}/{rel}/``.

    A relation is skipped when its ``ent_tuples.json`` already exists. That
    includes the empty placeholder written when harvesting for the relation
    starts, not only completed results.

    Args:
        rel_set: Relation-set name; selects ``relation_info/{rel_set}.json``
            and is the first component of the output path.
        model_name: Forwarded to ``KnowledgeHarvester``; also the last
            component of the output path.
        max_n_ent_tuples: Forwarded to ``KnowledgeHarvester``; also part of
            the ``setting`` directory name.
        max_n_prompts: Forwarded to ``KnowledgeHarvester``; part of the
            ``setting`` directory name unless ``use_init_prompts`` is set.
        prompt_temp: Forwarded to ``KnowledgeHarvester``.
        max_word_repeat: Forwarded to ``KnowledgeHarvester``.
        max_ent_subwords: Forwarded to ``KnowledgeHarvester``.
        use_init_prompts: If true, use only each relation's ``init_prompts``;
            otherwise use the de-duplicated union of ``init_prompts`` and
            ``prompts``.
    """
    knowledge_harvester = KnowledgeHarvester(
        model_name=model_name,
        max_n_ent_tuples=max_n_ent_tuples,
        max_n_prompts=max_n_prompts,
        max_word_repeat=max_word_repeat,
        max_ent_subwords=max_ent_subwords,
        prompt_temp=prompt_temp)

    with open(f'relation_info/{rel_set}.json', encoding='utf-8') as f:
        relation_info = json.load(f)

    # The output directory depends only on the CLI arguments, so build it
    # once instead of once per relation.
    setting = f'{max_n_ent_tuples}tuples'
    if use_init_prompts:
        setting += '_initprompts'
    else:
        setting += f'_top{max_n_prompts}prompts'
    output_dir = f'results/{rel_set}/{setting}/{model_name}'

    for rel, info in relation_info.items():
        print(f'Harvesting for relation {rel}...')

        rel_dir = f'{output_dir}/{rel}'
        ent_tuples_path = f'{rel_dir}/ent_tuples.json'
        if os.path.exists(ent_tuples_path):
            print(f'file {ent_tuples_path} exists, skipped.')
            continue

        os.makedirs(rel_dir, exist_ok=True)
        # Write an empty placeholder before the (slow) harvest so that any
        # later existence check skips this relation; presumably this lets
        # several jobs split the relations between them.
        # NOTE(review): if the process dies mid-harvest, the empty
        # placeholder remains and the relation is skipped on every later run
        # until the file is deleted by hand.
        _dump_json([], ent_tuples_path)

        knowledge_harvester.clear()
        knowledge_harvester.set_seed_ent_tuples(
            seed_ent_tuples=info['seed_ent_tuples'])

        if use_init_prompts:
            prompts = info['init_prompts']
        else:
            # De-duplicate while keeping a stable, input-determined order;
            # ``list(set(...))`` ordering depends on string hash
            # randomization and so changed from run to run.
            prompts = list(dict.fromkeys(
                info['init_prompts'] + info['prompts']))
        knowledge_harvester.set_prompts(prompts=prompts)

        knowledge_harvester.update_prompts()
        _dump_json(knowledge_harvester.weighted_prompts,
                   f'{rel_dir}/prompts.json', indent=4)
        for prompt, weight in knowledge_harvester.weighted_prompts:
            print(f'{weight:.4f} {prompt}')

        knowledge_harvester.update_ent_tuples()
        # Overwrites the empty placeholder with the real results.
        _dump_json(knowledge_harvester.weighted_ent_tuples,
                   ent_tuples_path, indent=4)


def _dump_json(obj, path, indent=None):
    """Write ``obj`` as JSON to ``path``, closing (and flushing) the file."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=indent)
|
|
|
|
|
|
if __name__ == '__main__':
    # Hand ``main`` to Python Fire, which turns its keyword arguments into
    # command-line flags (e.g. ``--rel_set=conceptnet``).
    fire.Fire(component=main)
|