# Mirror of https://github.com/tanyuqian/knowledge-harvest-from-lms.git
# (synced 2023-06-02 01:35:42 +03:00).
# Original file: 61 lines, 2.0 KiB, Python.
import os
|
|
import sys
|
|
import fire
|
|
import json
|
|
import random
|
|
from prettytable import PrettyTable
|
|
|
|
|
|
def _load_json(path):
    """Read and return the JSON document stored at ``path``, closing the file."""
    with open(path, encoding='utf-8') as fin:
        return json.load(fin)


def _build_table(columns, n_rows):
    """Return a PrettyTable with one column per ``columns`` entry.

    Every column is truncated or padded (with a backslash placeholder) to
    exactly ``n_rows`` cells, because PrettyTable refuses to add a column
    whose length differs from the current row count. The input lists are
    copied, never mutated.
    """
    table = PrettyTable()
    for header, cells in columns.items():
        cells = list(cells[:n_rows])
        cells.extend(['\\'] * (n_rows - len(cells)))
        table.add_column(header, cells)
    return table


def _print_results(rel, weighted_prompts, table, output_file):
    """Write the report for relation ``rel`` (prompts, then tuple table)."""
    print(f'Relation: {rel}', file=output_file)
    print('Prompts:', file=output_file)
    # NOTE(review): assumes prompts.json holds [prompt, weight] pairs.
    for prompt, weight in weighted_prompts:
        print(f'- {weight:.4f} {prompt}', file=output_file)
    print('Harvested Tuples:', file=output_file)
    print(table, file=output_file)
    print('=' * 50, file=output_file, flush=True)


def main(result_dir, n_present=20, n_random_pool=200):
    """Summarize harvested prompts and entity tuples for every relation.

    For each relation in ``relation_info/<rel_set>.json`` this reports the
    weighted prompts and a table with three columns:
    - the seed tuples;
    - the top ``n_present`` harvested tuples;
    - a random sample of up to ``n_present`` tuples drawn from the top
      ``n_random_pool`` harvested tuples.

    The report goes both to ``<result_dir>/summary.txt`` and to stdout.
    Relations whose ``ent_tuples.json`` is missing or empty are skipped.

    Args:
        result_dir: Output directory of a harvesting run. The relation-set
            name is its second ``/``-separated component, i.e. the layout is
            expected to be ``<root>/<rel_set>/...``.
        n_present: Number of rows shown in every table column.
        n_random_pool: Number of top-ranked tuples kept per relation; the
            random column is sampled from this pool.
    """
    # Second path component names the relation set, e.g. "results/<rel_set>/...".
    rel_set = result_dir.split('/')[1]
    relation_info = _load_json(f'relation_info/{rel_set}.json')

    with open(f'{result_dir}/summary.txt', 'w',
              encoding='utf-8') as summary_file:
        for rel, info in relation_info.items():
            seed_ent_tuples = info['seed_ent_tuples']
            rel_dir = f'{result_dir}/{rel}'

            if not os.path.exists(f'{rel_dir}/ent_tuples.json'):
                print(f'outputs of relation "{rel}" not found. skipped.')
                continue

            weighted_prompts = _load_json(f'{rel_dir}/prompts.json')
            weighted_ent_tuples = _load_json(f'{rel_dir}/ent_tuples.json')

            if not weighted_ent_tuples:
                print(f'outputs of relation "{rel}" not found. skipped.')
                continue
            weighted_ent_tuples = weighted_ent_tuples[:n_random_pool]

            # random.sample raises ValueError if asked for more items than
            # the population holds; cap the sample size at the pool size.
            n_sample = min(n_present, len(weighted_ent_tuples))
            columns = {
                'Seed samples': seed_ent_tuples,
                f'Ours (Top {n_present})': [
                    str(ent_tuple)
                    for ent_tuple, _ in weighted_ent_tuples[:n_present]],
                f'Ours (Random samples over top {n_random_pool} tuples)': [
                    str(ent_tuple)
                    for ent_tuple, _ in random.sample(
                        weighted_ent_tuples, n_sample)],
            }
            table = _build_table(columns, n_present)

            _print_results(rel, weighted_prompts, table, summary_file)
            _print_results(rel, weighted_prompts, table, sys.stdout)

    print(f'This summary has been saved into {summary_file.name}.')
|
|
|
|
|
|
if __name__ == '__main__':
    # Let python-fire expose main()'s parameters as command-line flags.
    fire.Fire(component=main)
|