mirror of
https://github.com/BaranziniLab/KG_RAG.git
synced 2024-06-08 14:12:54 +03:00
323 lines
15 KiB
Plaintext
323 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "e38c27d7-ebfe-406f-aa9f-07f9d6662d52",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"from tqdm import tqdm\n",
|
|
"import re\n",
|
|
"from scipy import stats\n",
|
|
"import seaborn as sns\n",
|
|
"import matplotlib.pyplot as plt\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "f3c91843",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def extract_answer(text):\n",
|
|
" pattern = r\"(True|False|Don't know)\"\n",
|
|
" matches = re.findall(pattern, text)\n",
|
|
" return matches\n",
|
|
"\n",
|
|
"def process_df(prompt_response_df, rag_response_df):\n",
|
|
" prompt_response_df.loc[:, \"answer_count\"] = prompt_response_df.extracted_answer.apply(lambda x:len(x))\n",
|
|
" prompt_response_df_multiple_answers = prompt_response_df[prompt_response_df.answer_count > 1]\n",
|
|
" prompt_response_df_single_answer = prompt_response_df.drop(prompt_response_df_multiple_answers.index)\n",
|
|
" prompt_response_df_single_answer.drop(\"answer_count\", axis=1, inplace=True)\n",
|
|
" prompt_response_df_multiple_answers_ = []\n",
|
|
" for index, row in prompt_response_df_multiple_answers.iterrows():\n",
|
|
" if row[\"extracted_answer\"][0] == row[\"extracted_answer\"][1]:\n",
|
|
" prompt_response_df_multiple_answers_.append((row[\"question\"], row[\"label\"], row[\"llm_answer\"], row[\"extracted_answer\"][0]))\n",
|
|
" else:\n",
|
|
" prompt_response_df_multiple_answers_.append((row[\"question\"], row[\"label\"], row[\"llm_answer\"], None))\n",
|
|
" prompt_response_df_multiple_answers_ = pd.DataFrame(prompt_response_df_multiple_answers_, columns=[\"question\", \"label\", \"llm_answer\", \"extracted_answer\"])\n",
|
|
" prompt_response_df_final = pd.concat([prompt_response_df_single_answer, prompt_response_df_multiple_answers_], ignore_index=True)\n",
|
|
" prompt_response_df_final = prompt_response_df_final.explode(\"extracted_answer\")\n",
|
|
"\n",
|
|
" rag_response_df.loc[:, \"answer_count\"] = rag_response_df.extracted_answer.apply(lambda x:len(x))\n",
|
|
" rag_response_df_multiple_answers = rag_response_df[rag_response_df.answer_count > 1]\n",
|
|
" rag_response_df_single_answer = rag_response_df.drop(rag_response_df_multiple_answers.index)\n",
|
|
" rag_response_df_single_answer.drop(\"answer_count\", axis=1, inplace=True)\n",
|
|
" rag_response_df_multiple_answers_ = []\n",
|
|
" for index, row in rag_response_df_multiple_answers.iterrows():\n",
|
|
" if row[\"extracted_answer\"][0] == row[\"extracted_answer\"][1]:\n",
|
|
" rag_response_df_multiple_answers_.append((row[\"question\"], row[\"label\"], row[\"llm_answer\"], row[\"extracted_answer\"][0]))\n",
|
|
" else:\n",
|
|
" rag_response_df_multiple_answers_.append((row[\"question\"], row[\"label\"], row[\"llm_answer\"], None))\n",
|
|
" rag_response_df_multiple_answers_ = pd.DataFrame(rag_response_df_multiple_answers_, columns=[\"question\", \"label\", \"llm_answer\", \"extracted_answer\"])\n",
|
|
" rag_response_df_final = pd.concat([rag_response_df_single_answer, rag_response_df_multiple_answers_], ignore_index=True)\n",
|
|
" rag_response_df_final = rag_response_df_final.explode(\"extracted_answer\")\n",
|
|
" \n",
|
|
" prompt_incorrect_answers_because_of_na = prompt_response_df_final[prompt_response_df_final.extracted_answer.isna()]\n",
|
|
" rag_incorrect_answers_because_of_na = rag_response_df_final[rag_response_df_final.extracted_answer.isna()]\n",
|
|
"\n",
|
|
" row_index_to_drop = list(prompt_incorrect_answers_because_of_na.index.values) + list(rag_incorrect_answers_because_of_na.index.values)\n",
|
|
"\n",
|
|
" prompt_response_df_final.drop(row_index_to_drop, inplace=True)\n",
|
|
" rag_response_df_final.drop(row_index_to_drop, inplace=True)\n",
|
|
"\n",
|
|
" prompt_response_df_final = prompt_response_df_final.reset_index()\n",
|
|
" rag_response_df_final = rag_response_df_final.reset_index()\n",
|
|
" response_transform = {\n",
|
|
" \"True\" : True,\n",
|
|
" \"False\" : False\n",
|
|
" }\n",
|
|
"\n",
|
|
" prompt_response_df_final.extracted_answer = prompt_response_df_final.extracted_answer.apply(lambda x:response_transform[x])\n",
|
|
" rag_response_df_final.extracted_answer = rag_response_df_final.extracted_answer.apply(lambda x:response_transform[x])\n",
|
|
"\n",
|
|
" return prompt_response_df_final, rag_response_df_final\n",
|
|
"\n",
|
|
"\n",
|
|
"def evaluate(df):\n",
|
|
" correct = df[df.label == df.extracted_answer]\n",
|
|
" incorrect = df[df.label != df.extracted_answer]\n",
|
|
" correct_frac = correct.shape[0]/df.shape[0]\n",
|
|
" incorrect_frac = incorrect.shape[0]/df.shape[0]\n",
|
|
" return correct_frac, incorrect_frac\n",
|
|
"\n",
|
|
"\n",
|
|
"def bootstrap(prompt_response_df_final, rag_response_df_final, niter = 1000, nsample = 150):\n",
|
|
" prompt_correct_frac_list = []\n",
|
|
" rag_correct_frac_list = []\n",
|
|
" for i in tqdm(range(niter)):\n",
|
|
" prompt_response_df_final_sample = prompt_response_df_final.sample(n=nsample, random_state=i)\n",
|
|
" prompt_correct_frac, prompt_incorrect_frac = evaluate(prompt_response_df_final_sample)\n",
|
|
" rag_response_df_final_sample = rag_response_df_final.iloc[prompt_response_df_final_sample.index]\n",
|
|
" rag_correct_frac, rag_incorrect_frac = evaluate(rag_response_df_final_sample)\n",
|
|
" prompt_correct_frac_list.append(prompt_correct_frac)\n",
|
|
" rag_correct_frac_list.append(rag_correct_frac)\n",
|
|
" return prompt_correct_frac_list, rag_correct_frac_list\n",
|
|
"\n",
|
|
"def plot_figure(prompt_correct_frac_list, rag_correct_frac_list):\n",
|
|
" fig = plt.figure(figsize=(5, 3))\n",
|
|
" ax = plt.gca()\n",
|
|
"\n",
|
|
" sns.kdeplot(prompt_correct_frac_list, color=\"blue\", shade=True, label=\"Prompt based\", ax=ax, lw=2, linestyle=\"-\", alpha=0.6)\n",
|
|
" sns.kdeplot(rag_correct_frac_list, color=\"lightcoral\", shade=True, label=\"KG-RAG based\", ax=ax, lw=2, linestyle=\"-\", alpha=0.6)\n",
|
|
"\n",
|
|
" for artist in ax.lines:\n",
|
|
" artist.set_edgecolor(\"black\")\n",
|
|
"\n",
|
|
" plt.xlabel(\"Accuracy\")\n",
|
|
" plt.ylabel(\"Density\")\n",
|
|
" plt.legend(loc=\"upper left\")\n",
|
|
" plt.xlim(0.75,1)\n",
|
|
"\n",
|
|
" ax.axvline(np.mean(prompt_correct_frac_list), color='black', linestyle='--', lw=2)\n",
|
|
" ax.axvline(np.mean(rag_correct_frac_list), color='black', linestyle='--', lw=2)\n",
|
|
"\n",
|
|
" sns.despine(top=True, right=True)\n",
|
|
"\n",
|
|
" plt.show()\n",
|
|
" return fig\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "a5723e54-a5ef-48cd-85d1-88d0fe7ea78e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"'''\n",
|
|
"Following files can be obtained by running the run_true_false_generation.py script.\n",
|
|
"Make sure to change the parent path and filenames based on where and how you save the files\n",
|
|
"'''\n",
|
|
"\n",
|
|
"data_path ='../data/results'\n",
|
|
"\n",
|
|
"llama_prompt_path = os.path.join(data_path, 'Llama_2_13b_chat_hf_prompt_based_true_false_binary_response.csv')\n",
|
|
"llama_kg_rag_path = os.path.join(data_path, 'Llama_2_13b_chat_hf_PubMedBert_and_entity_recognition_based_node_retrieval_rag_based_true_false_binary_response.csv')\n",
|
|
"\n",
|
|
"gpt_35_prompt_path = os.path.join(data_path, 'gpt_35_turbo_prompt_based_true_false_binary_response.csv')\n",
|
|
"gpt_35_kg_rag_path = os.path.join(data_path, 'gpt_35_turbo_PubMedBert_and_entity_recognition_based_node_retrieval_rag_based_true_false_binary_response.csv')\n",
|
|
"\n",
|
|
"gpt_4_prompt_path = os.path.join(data_path, 'gpt_4_prompt_based_true_false_binary_response.csv')\n",
|
|
"gpt_4_kg_rag_path = os.path.join(data_path, 'gpt_4_PubMedBert_and_entity_recognition_based_node_retrieval_rag_based_true_false_binary_response.csv')\n",
|
|
"\n",
|
|
"curated_data = pd.read_csv('../data/benchmark_data/true_false_questions.csv').drop('Unnamed: 0', axis=1)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "47f3c405",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"llama_prompt_df = pd.read_csv(llama_prompt_path)\n",
|
|
"llama_prompt_df = pd.merge(curated_data, llama_prompt_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n",
|
|
"\n",
|
|
"llama_kg_rag_df = pd.read_csv(llama_kg_rag_path)\n",
|
|
"llama_kg_rag_df = pd.merge(curated_data, llama_kg_rag_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n",
|
|
"\n",
|
|
"gpt_35_prompt_df = pd.read_csv(gpt_35_prompt_path)\n",
|
|
"gpt_35_prompt_df = pd.merge(curated_data, gpt_35_prompt_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n",
|
|
"\n",
|
|
"gpt_35_kg_rag_df = pd.read_csv(gpt_35_kg_rag_path)\n",
|
|
"gpt_35_kg_rag_df = pd.merge(curated_data, gpt_35_kg_rag_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n",
|
|
"\n",
|
|
"gpt_4_prompt_df = pd.read_csv(gpt_4_prompt_path)\n",
|
|
"gpt_4_prompt_df = pd.merge(curated_data, gpt_4_prompt_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n",
|
|
"\n",
|
|
"gpt_4_kg_rag_df = pd.read_csv(gpt_4_kg_rag_path)\n",
|
|
"gpt_4_kg_rag_df = pd.merge(curated_data, gpt_4_kg_rag_df, left_on='text', right_on='question').drop(['text', 'label_y'], axis=1).rename(columns={'label_x':'label'})\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "0ca02159",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"llama_prompt_df.loc[:, 'extracted_answer'] = llama_prompt_df['llm_answer'].apply(extract_answer)\n",
|
|
"llama_kg_rag_df.loc[:, 'extracted_answer'] = llama_kg_rag_df['llm_answer'].apply(extract_answer)\n",
|
|
"\n",
|
|
"gpt_35_prompt_df.loc[:, 'extracted_answer'] = gpt_35_prompt_df['llm_answer'].apply(extract_answer)\n",
|
|
"gpt_35_kg_rag_df.loc[:, 'extracted_answer'] = gpt_35_kg_rag_df['llm_answer'].apply(extract_answer)\n",
|
|
"\n",
|
|
"gpt_4_prompt_df.loc[:, 'extracted_answer'] = gpt_4_prompt_df['llm_answer'].apply(extract_answer)\n",
|
|
"gpt_4_kg_rag_df.loc[:, 'extracted_answer'] = gpt_4_kg_rag_df['llm_answer'].apply(extract_answer)\n",
|
|
"\n",
|
|
"llama_prompt_df, llama_kg_rag_df = process_df(llama_prompt_df, llama_kg_rag_df)\n",
|
|
"\n",
|
|
"gpt_35_prompt_df, gpt_35_kg_rag_df = process_df(gpt_35_prompt_df, gpt_35_kg_rag_df)\n",
|
|
"\n",
|
|
"gpt_4_prompt_df, gpt_4_kg_rag_df = process_df(gpt_4_prompt_df, gpt_4_kg_rag_df)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "90ccd105",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|█████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1323.05it/s]\n",
|
|
"100%|█████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1207.89it/s]\n",
|
|
" 22%|██████████████████▍ | 225/1000 [00:00<00:00, 1134.89it/s]"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"llama_prompt_correct_frac_list, llama_rag_correct_frac_list = bootstrap(llama_prompt_df, llama_kg_rag_df)\n",
|
|
"\n",
|
|
"gpt_35_prompt_correct_frac_list, gpt_35_rag_correct_frac_list = bootstrap(gpt_35_prompt_df, gpt_35_kg_rag_df)\n",
|
|
"\n",
|
|
"gpt_4_prompt_correct_frac_list, gpt_4_rag_correct_frac_list = bootstrap(gpt_4_prompt_df, gpt_4_kg_rag_df)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "428528e7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"llama_fig = plot_figure(llama_prompt_correct_frac_list, llama_rag_correct_frac_list)\n",
|
|
"\n",
|
|
"fig_path = '../data/results/figures'\n",
|
|
"os.makedirs(fig_path, exist_ok=True)\n",
|
|
"llama_fig.savefig(os.path.join(fig_path, 'llama_true_false.svg'), format='svg', bbox_inches='tight') \n",
|
|
"\n",
|
|
"print('---Prompt based mean and std---')\n",
|
|
"print(np.mean(llama_prompt_correct_frac_list))\n",
|
|
"print(np.std(llama_prompt_correct_frac_list))\n",
|
|
"print('')\n",
|
|
"print('---KG-RAG based mean and std---')\n",
|
|
"print(np.mean(llama_rag_correct_frac_list))\n",
|
|
"print(np.std(llama_rag_correct_frac_list))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "49f9f1ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gpt_35_fig = plot_figure(gpt_35_prompt_correct_frac_list, gpt_35_rag_correct_frac_list)\n",
|
|
"\n",
|
|
"fig_path = '../data/results/figures'\n",
|
|
"os.makedirs(fig_path, exist_ok=True)\n",
|
|
"gpt_35_fig.savefig(os.path.join(fig_path, 'gpt_35_true_false.svg'), format='svg', bbox_inches='tight') \n",
|
|
"\n",
|
|
"print('---Prompt based mean and std---')\n",
|
|
"print(np.mean(gpt_35_prompt_correct_frac_list))\n",
|
|
"print(np.std(gpt_35_prompt_correct_frac_list))\n",
|
|
"print('')\n",
|
|
"print('---KG-RAG based mean and std---')\n",
|
|
"print(np.mean(gpt_35_rag_correct_frac_list))\n",
|
|
"print(np.std(gpt_35_rag_correct_frac_list))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e9f75a7f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gpt_4_fig = plot_figure(gpt_4_prompt_correct_frac_list, gpt_4_rag_correct_frac_list)\n",
|
|
"\n",
|
|
"fig_path = '../data/results/figures'\n",
|
|
"os.makedirs(fig_path, exist_ok=True)\n",
|
|
"gpt_4_fig.savefig(os.path.join(fig_path, 'gpt_4_true_false.svg'), format='svg', bbox_inches='tight') \n",
|
|
"\n",
|
|
"print('---Prompt based mean and std---')\n",
|
|
"print(np.mean(gpt_4_prompt_correct_frac_list))\n",
|
|
"print(np.std(gpt_4_prompt_correct_frac_list))\n",
|
|
"print('')\n",
|
|
"print('---KG-RAG based mean and std---')\n",
|
|
"print(np.mean(gpt_4_rag_correct_frac_list))\n",
|
|
"print(np.std(gpt_4_rag_correct_frac_list))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "00bc1f64",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|