mirror of
https://github.com/BaranziniLab/KG_RAG.git
synced 2024-06-08 14:12:54 +03:00
adding kg-rag notebook
This commit is contained in:
179
notebooks/kg_rag.ipynb
Normal file
179
notebooks/kg_rag.ipynb
Normal file
@@ -0,0 +1,179 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "b86c2320-71ed-4223-9df7-0b9281cb652c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.chdir('..')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "8e9dc80f-43a6-4d8d-9d99-343bc6515ff8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from kg_rag.utility import *\n",
|
||||
"from tqdm import tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "db3c5056-15d6-4608-87c8-1e897dc4075e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure KG-RAG"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "fdf4d8fd-2265-4237-ba85-06a3efbf8145",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SYSTEM_PROMPT = system_prompts[\"KG_RAG_BASED_TEXT_GENERATION\"]\n",
|
||||
"CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
|
||||
"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
|
||||
"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
|
||||
"VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
|
||||
"NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
|
||||
"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
|
||||
"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
|
||||
"TEMPERATURE = config_data[\"LLM_TEMPERATURE\"]\n",
|
||||
"\n",
|
||||
"CHAT_MODEL_ID = 'gpt-4'\n",
|
||||
"EDGE_EVIDENCE = True\n",
|
||||
"\n",
|
||||
"CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
|
||||
"embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
|
||||
"node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "547cf664-8b48-4f19-a232-09a5b2fa4ffa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load test data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "00fa2491-901e-44ea-8109-2a60b23771ba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = pd.read_csv('data/rag_comparison_data.csv')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "39c207c9-49be-449b-9b70-a92cdf8095d3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Function for chat completion with token usage tracking"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8ca41e38-79fb-4f68-aa16-db1785b6551f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):\n",
|
||||
" response = openai.ChatCompletion.create(\n",
|
||||
" temperature=temperature,\n",
|
||||
" deployment_id=chat_deployment_id,\n",
|
||||
" model=chat_model_id,\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
||||
" {\"role\": \"user\", \"content\": instruction}\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" return response['choices'][0]['message']['content'], response.usage.total_tokens\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0fdf1242-c2d3-4dc6-9a81-ee672bb1c7a8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run on test data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "637671b2-a06c-4fe4-a7a6-855b0ba48fcd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"KeyboardInterrupt\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"kg_rag_answer = []\n",
|
||||
"total_tokens_used = []\n",
|
||||
"\n",
|
||||
"for index, row in tqdm(data.iterrows()):\n",
|
||||
" question = row['question']\n",
|
||||
" context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
|
||||
" enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n",
|
||||
" output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
|
||||
" kg_rag_answer.append(output)\n",
|
||||
" total_tokens_used.append(token_usage)\n",
|
||||
" \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fcd4aaf-1d64-4aef-983e-51c02ca7d223",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user