mirror of
https://github.com/BaranziniLab/KG_RAG.git
synced 2024-06-08 14:12:54 +03:00
adding true-false notebook
This commit is contained in:
@@ -37,7 +37,7 @@ SAVE_RESULTS_PATH : '/data/somank/KG_RAG/data/analysis_results'
|
||||
|
||||
# File paths for test questions
|
||||
MCQ_PATH : '/data/somank/KG_RAG/data/benchmark_data/test_questions_two_hop_mcq_from_monarch_and_robokop.csv'
|
||||
TRUE_FALSE_PATH : '/data/somank/KG_RAG/data/benchmark_data/test_questions_one_hop_true_false_v3.csv'
|
||||
TRUE_FALSE_PATH : '/data/somank/kg_rag_fork/KG_RAG/data/benchmark_data/test_questions_one_hop_true_false_v3.csv'
|
||||
ONE_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/one_hop_graph_traversal_questions_v2.csv'
|
||||
TWO_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/two_hop_graph_traversal_questions.csv'
|
||||
|
||||
|
||||
160
notebooks/true_false_data.ipynb
Normal file
160
notebooks/true_false_data.ipynb
Normal file
@@ -0,0 +1,160 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6396b4b5-4a64-4a91-9dd0-8961dd1fb7ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.chdir('..')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e1f43477-120b-4cc5-a588-22b5f18eee92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from kg_rag.utility import *\n",
|
||||
"from tqdm import tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6a0e7155-fa9e-46cc-98df-1950603b1193",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Choose the LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "75bb74c1-e4ac-48d2-ad36-a793f4c140c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LLM_MODEL = 'gpt-35-turbo'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4aa934b0-81a0-4ab7-b144-af53a350bf1a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure KG-RAG"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "155bda08-d15b-413c-89d1-d1f97f43bb30",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"QUESTION_PATH = config_data[\"TRUE_FALSE_PATH\"]\n",
|
||||
"SYSTEM_PROMPT = system_prompts[\"TRUE_FALSE_QUESTION\"]\n",
|
||||
"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
|
||||
"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
|
||||
"VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
|
||||
"NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
|
||||
"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
|
||||
"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
|
||||
"TEMPERATURE = config_data[\"LLM_TEMPERATURE\"]\n",
|
||||
"SAVE_PATH = config_data[\"SAVE_RESULTS_PATH\"]\n",
|
||||
"CONTEXT_VOLUME = 100\n",
|
||||
"EDGE_EVIDENCE = False\n",
|
||||
"\n",
|
||||
"CHAT_MODEL_ID = LLM_MODEL\n",
|
||||
"CHAT_DEPLOYMENT_ID = LLM_MODEL\n",
|
||||
"\n",
|
||||
"vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
|
||||
"embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
|
||||
"node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "df93fd81-3cd1-4a87-b024-ea21f4c79956",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load test data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "041b21f3-8746-47ff-b4f2-c3c29f2a0dcf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question_df = pd.read_csv(QUESTION_PATH)\n",
|
||||
" \n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "6f4632d9-6a60-4cfe-851d-4d65d4089a52",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0it [01:21, ?it/s]\n",
|
||||
"\n",
|
||||
"KeyboardInterrupt\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"answer_list = []\n",
|
||||
"for index, row in tqdm(question_df.iterrows()):\n",
|
||||
" question = row[\"text\"]\n",
|
||||
" context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
|
||||
" enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \"+ question\n",
|
||||
" output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
|
||||
" answer_list.append((row[\"text\"], row[\"label\"], output))\n",
|
||||
"\n",
|
||||
"answer_df = pd.DataFrame(answer_list, columns=[\"question\", \"label\", \"llm_answer\"])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bf710d7c-db58-4083-9762-03de0dd5eb1a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user