adding new files

This commit is contained in:
Karthik Soman
2024-03-18 17:48:19 -07:00
parent 2ab25a6d8f
commit e437442faf
13 changed files with 1298 additions and 498 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
notebooks/cypher_rag_using_langchain_output_extraction.ipynb
data/arxiv_data
notebooks/neo4j_rag_using_langchain_3M.ipynb
cachegpt

View File

@@ -38,8 +38,8 @@ SAVE_RESULTS_PATH : '/data/somank/KG_RAG/data/analysis_results'
# File paths for test questions
MCQ_PATH : '/data/somank/KG_RAG/data/benchmark_data/test_questions_two_hop_mcq_from_monarch_and_robokop.csv'
TRUE_FALSE_PATH : '/data/somank/kg_rag_fork/KG_RAG/data/benchmark_data/test_questions_one_hop_true_false_v3.csv'
ONE_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/one_hop_graph_traversal_questions_v2.csv'
TWO_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/two_hop_graph_traversal_questions.csv'
ONE_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/single_disease_entity_prompts.csv'
TWO_HOP_GRAPH_TRAVERSAL : '/data/somank/KG_RAG/data/hyperparam_tuning_data/two_disease_entity_prompts.csv'
# SPOKE-API params
BASE_URI : 'https://spoke.rbvi.ucsf.edu'

View File

@@ -6,7 +6,7 @@
<rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:date>2024-03-17T02:37:48.089984</dc:date>
<dc:date>2024-03-18T17:45:36.644061</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
@@ -43,7 +43,7 @@ L 76.112951 184.176471
L 76.112951 54.293532
L 52.66108 54.293532
z
" clip-path="url(#pd6d26708b6)" style="fill: #ff0000"/>
" clip-path="url(#pb406e27576)" style="fill: #ff0000"/>
</g>
<g id="patch_4">
<path d="M 119.666427 184.176471
@@ -51,7 +51,7 @@ L 143.118299 184.176471
L 143.118299 184.176471
L 119.666427 184.176471
z
" clip-path="url(#pd6d26708b6)" style="fill: #ff0000"/>
" clip-path="url(#pb406e27576)" style="fill: #ff0000"/>
</g>
<g id="patch_5">
<path d="M 76.112951 184.176471
@@ -59,7 +59,7 @@ L 99.564823 184.176471
L 99.564823 16.194536
L 76.112951 16.194536
z
" clip-path="url(#pd6d26708b6)" style="fill: #008000"/>
" clip-path="url(#pb406e27576)" style="fill: #008000"/>
</g>
<g id="patch_6">
<path d="M 143.118299 184.176471
@@ -67,18 +67,18 @@ L 166.57017 184.176471
L 166.57017 16.194536
L 143.118299 16.194536
z
" clip-path="url(#pd6d26708b6)" style="fill: #008000"/>
" clip-path="url(#pb406e27576)" style="fill: #008000"/>
</g>
<g id="matplotlib.axis_1">
<g id="xtick_1">
<g id="line2d_1">
<defs>
<path id="mf3be53fb4d" d="M 0 0
<path id="m7ec6a7be22" d="M 0 0
L 0 3.5
" style="stroke: #000000; stroke-width: 0.8"/>
</defs>
<g>
<use xlink:href="#mf3be53fb4d" x="76.112951" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m7ec6a7be22" x="76.112951" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_1">
@@ -302,7 +302,7 @@ z
<g id="xtick_2">
<g id="line2d_2">
<g>
<use xlink:href="#mf3be53fb4d" x="143.118299" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m7ec6a7be22" x="143.118299" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_2">
@@ -348,12 +348,12 @@ z
<g id="ytick_1">
<g id="line2d_3">
<defs>
<path id="md3933c749c" d="M 0 0
<path id="m57963c909c" d="M 0 0
L -3.5 0
" style="stroke: #000000; stroke-width: 0.8"/>
</defs>
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="184.176471" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_3">
@@ -389,7 +389,7 @@ z
<g id="ytick_2">
<g id="line2d_4">
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="149.54102" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="149.54102" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_4">
@@ -429,7 +429,7 @@ z
<g id="ytick_3">
<g id="line2d_5">
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="114.90557" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="114.90557" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_5">
@@ -464,7 +464,7 @@ z
<g id="ytick_4">
<g id="line2d_6">
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="80.270119" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="80.270119" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_6">
@@ -510,7 +510,7 @@ z
<g id="ytick_5">
<g id="line2d_7">
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="45.634669" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="45.634669" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_7">
@@ -565,7 +565,7 @@ z
<g id="ytick_6">
<g id="line2d_8">
<g>
<use xlink:href="#md3933c749c" x="46.965625" y="10.999219" style="stroke: #000000; stroke-width: 0.8"/>
<use xlink:href="#m57963c909c" x="46.965625" y="10.999219" style="stroke: #000000; stroke-width: 0.8"/>
</g>
</g>
<g id="text_8">
@@ -925,7 +925,7 @@ z
</g>
</g>
<defs>
<clipPath id="pd6d26708b6">
<clipPath id="pb406e27576">
<rect x="46.965625" y="7.79544" width="125.3" height="176.381031"/>
</clipPath>
</defs>

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

View File

@@ -0,0 +1,502 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "f0ed1d29",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import GraphCypherQAChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.graphs import Neo4jGraph\n",
"from langchain.callbacks import get_openai_callback\n",
"from dotenv import load_dotenv\n",
"import os\n",
"import openai\n",
"import pandas as pd\n",
"from neo4j.exceptions import CypherSyntaxError\n"
]
},
{
"cell_type": "markdown",
"id": "1d905ac1",
"metadata": {},
"source": [
"## Choose the LLM"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "763d7ad7",
"metadata": {},
"outputs": [],
"source": [
"LLM_MODEL = 'gpt-4-32k'\n"
]
},
{
"cell_type": "markdown",
"id": "f02bd807",
"metadata": {},
"source": [
"## Load test data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "43772231",
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('../data/rag_comparison_data.csv')\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "8bb1d1e3",
"metadata": {},
"source": [
"## Custom function for neo4j RAG chain"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "611fee8c",
"metadata": {},
"outputs": [],
"source": [
"def get_neo4j_cypher_rag_chain():\n",
" load_dotenv(os.path.join(os.path.expanduser('~'), '.mate_neo4j_config.env'))\n",
" username = os.environ.get('MATE_USR')\n",
" password = os.environ.get('MATE_PSW')\n",
" url = os.environ.get('MATE_URI')\n",
" database = os.environ.get('DB_NAME')\n",
" \n",
" graph = Neo4jGraph(\n",
" url=url, \n",
" username=username, \n",
" password=password,\n",
" database = database\n",
" )\n",
"\n",
" load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
" API_KEY = os.environ.get('API_KEY')\n",
" API_VERSION = os.environ.get('API_VERSION')\n",
" RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n",
" openai.api_type = \"azure\"\n",
" openai.api_key = API_KEY\n",
" openai.api_base = RESOURCE_ENDPOINT\n",
" openai.api_version = API_VERSION\n",
" chat_deployment_id = LLM_MODEL\n",
" chat_model_id = chat_deployment_id\n",
" temperature = 0\n",
" chat_model = ChatOpenAI(openai_api_key=API_KEY, \n",
" engine=chat_deployment_id, \n",
" temperature=temperature)\n",
" chain = GraphCypherQAChain.from_llm(\n",
" chat_model, \n",
" graph=graph, \n",
" verbose=True, \n",
" validate_cypher=True,\n",
" return_intermediate_steps=True\n",
" )\n",
" return chain"
]
},
{
"cell_type": "markdown",
"id": "8b920685",
"metadata": {},
"source": [
"## Initiate neo4j RAG chain"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "29b40370",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING! engine is not default parameter.\n",
" engine was transferred to model_kwargs.\n",
" Please confirm that engine is what you intended.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.4 ms, sys: 5.35 ms, total: 18.7 ms\n",
"Wall time: 71.1 ms\n"
]
}
],
"source": [
"%%time\n",
"neo4j_rag_chain = get_neo4j_cypher_rag_chain()\n"
]
},
{
"cell_type": "markdown",
"id": "cfa082f8",
"metadata": {},
"source": [
"## Run on test data"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "172e0f96",
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
"Generated Cypher:\n",
"\u001b[32;1m\u001b[1;3mThe provided schema does not include a property for GWAS p-value or any nodes or relationships that would represent an association between a disease and a gene. Therefore, it is not possible to construct a Cypher statement to answer this question based on the provided schema.\u001b[0m\n"
]
},
{
"ename": "ValueError",
"evalue": "Length of values (1) does not match length of index (100)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m<timed exec>:15\u001b[0m\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/indexing.py:818\u001b[0m, in \u001b[0;36m_LocationIndexer.__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m 815\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_has_valid_setitem_indexer(key)\n\u001b[1;32m 817\u001b[0m iloc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miloc\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobj\u001b[38;5;241m.\u001b[39miloc\n\u001b[0;32m--> 818\u001b[0m \u001b[43miloc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_setitem_with_indexer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/indexing.py:1728\u001b[0m, in \u001b[0;36m_iLocIndexer._setitem_with_indexer\u001b[0;34m(self, indexer, value, name)\u001b[0m\n\u001b[1;32m 1725\u001b[0m \u001b[38;5;66;03m# add a new item with the dtype setup\u001b[39;00m\n\u001b[1;32m 1726\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m com\u001b[38;5;241m.\u001b[39mis_null_slice(indexer[\u001b[38;5;241m0\u001b[39m]):\n\u001b[1;32m 1727\u001b[0m \u001b[38;5;66;03m# We are setting an entire column\u001b[39;00m\n\u001b[0;32m-> 1728\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobj[key] \u001b[38;5;241m=\u001b[39m value\n\u001b[1;32m 1729\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1730\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_array_like(value):\n\u001b[1;32m 1731\u001b[0m \u001b[38;5;66;03m# GH#42099\u001b[39;00m\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/frame.py:3980\u001b[0m, in \u001b[0;36mDataFrame.__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m 3977\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_setitem_array([key], value)\n\u001b[1;32m 3978\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3979\u001b[0m \u001b[38;5;66;03m# set column\u001b[39;00m\n\u001b[0;32m-> 3980\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_set_item\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/frame.py:4174\u001b[0m, in \u001b[0;36mDataFrame._set_item\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m 4164\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_set_item\u001b[39m(\u001b[38;5;28mself\u001b[39m, key, value) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 4165\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 4166\u001b[0m \u001b[38;5;124;03m Add series to DataFrame in specified column.\u001b[39;00m\n\u001b[1;32m 4167\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 4172\u001b[0m \u001b[38;5;124;03m ensure homogeneity.\u001b[39;00m\n\u001b[1;32m 4173\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 4174\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sanitize_column\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4176\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 4177\u001b[0m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\n\u001b[1;32m 4178\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m value\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 4179\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_extension_array_dtype(value)\n\u001b[1;32m 4180\u001b[0m ):\n\u001b[1;32m 4181\u001b[0m \u001b[38;5;66;03m# broadcast across multiple columns if necessary\u001b[39;00m\n\u001b[1;32m 4182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mis_unique \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns, MultiIndex):\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/frame.py:4915\u001b[0m, in \u001b[0;36mDataFrame._sanitize_column\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 4912\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reindex_for_setitem(Series(value), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex)\n\u001b[1;32m 4914\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_list_like(value):\n\u001b[0;32m-> 4915\u001b[0m \u001b[43mcom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequire_length_match\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4916\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sanitize_array(value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, allow_2d\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/pandas/core/common.py:571\u001b[0m, in \u001b[0;36mrequire_length_match\u001b[0;34m(data, index)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 568\u001b[0m \u001b[38;5;124;03mCheck the length of data matches the length of the index.\u001b[39;00m\n\u001b[1;32m 569\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 570\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(data) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(index):\n\u001b[0;32m--> 571\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 572\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLength of values \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 573\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(data)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 574\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdoes not match length of index \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 575\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(index)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 576\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: Length of values (1) does not match length of index (100)"
]
}
],
"source": [
"%%time\n",
"\n",
"neo4j_rag_answer = []\n",
"total_tokens_used = []\n",
"\n",
"for index, row in data.iterrows():\n",
" question = row['question']\n",
" with get_openai_callback() as cb:\n",
" try:\n",
" neo4j_rag_answer.append(neo4j_rag_chain.run(query=question, return_final_only=True, verbose=False))\n",
" except ValueError as e:\n",
" neo4j_rag_answer.append(None)\n",
" total_tokens_used.append(cb.total_tokens) \n",
" \n",
"\n",
"data.loc[:,'neo4j_rag_answer'] = neo4j_rag_answer\n",
"data.loc[:, 'total_tokens_used'] = total_tokens_used\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "43e51892",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[473]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_tokens_used"
]
},
{
"cell_type": "markdown",
"id": "5ea4a4a9",
"metadata": {},
"source": [
"## Save the result"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "27ea4be4",
"metadata": {},
"outputs": [],
"source": [
"save_path = '../data/results'\n",
"os.makedirs(save_path, exist_ok=True)\n",
"data.to_csv(os.path.join(save_path, 'neo4j_rag_output.csv'), index=False)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9635a67c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>disease_name</th>\n",
" <th>gene_name</th>\n",
" <th>gwas_pvalue</th>\n",
" <th>question</th>\n",
" <th>question_perturbed</th>\n",
" <th>neo4j_rag_answer</th>\n",
" <th>total_tokens_used</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>childhood-onset asthma</td>\n",
" <td>RORA</td>\n",
" <td>2.000000e-37</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>None</td>\n",
" <td>10993</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>skin benign neoplasm</td>\n",
" <td>SHANK2</td>\n",
" <td>5.000000e-08</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>I'm sorry, but I don't know the answer.</td>\n",
" <td>11138</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>hypertrophic cardiomyopathy</td>\n",
" <td>AMBRA1</td>\n",
" <td>1.000000e-16</td>\n",
" <td>Is hypertrophic cardiomyopathy associated with...</td>\n",
" <td>Is hypertrophic cardiomyopathy associated with...</td>\n",
" <td>Yes, hypertrophic cardiomyopathy is associated...</td>\n",
" <td>11468</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>lung adenocarcinoma</td>\n",
" <td>CYP2A6</td>\n",
" <td>8.000000e-11</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>I'm sorry, but I don't have the information to...</td>\n",
" <td>11150</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>idiopathic generalized epilepsy</td>\n",
" <td>RYR2</td>\n",
" <td>3.000000e-09</td>\n",
" <td>Is idiopathic generalized epilepsy associated ...</td>\n",
" <td>Is idiopathic generalized epilepsy associated ...</td>\n",
" <td>No, idiopathic generalized epilepsy is not ass...</td>\n",
" <td>11129</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>lung squamous cell carcinoma</td>\n",
" <td>BRCA2</td>\n",
" <td>1.000000e-15</td>\n",
" <td>Is lung squamous cell carcinoma associated wit...</td>\n",
" <td>Is lung squamous cell carcinoma associated wit...</td>\n",
" <td>Yes, lung squamous cell carcinoma is associate...</td>\n",
" <td>11129</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>systemic lupus erythematosus</td>\n",
" <td>HLA-DRA</td>\n",
" <td>2.000000e-60</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>I'm sorry, but I don't have the information to...</td>\n",
" <td>11152</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>type 2 diabetes mellitus</td>\n",
" <td>UBE2E2</td>\n",
" <td>2.000000e-42</td>\n",
" <td>Is type 2 diabetes mellitus associated with UB...</td>\n",
" <td>Is type 2 diabetes mellitus associated with ub...</td>\n",
" <td>No, type 2 diabetes mellitus is not associated...</td>\n",
" <td>11139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>allergic rhinitis</td>\n",
" <td>HLA-DQA1</td>\n",
" <td>1.000000e-43</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>I'm sorry, but I don't know the answer.</td>\n",
" <td>11143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>systemic lupus erythematosus</td>\n",
" <td>HLA-DQA1</td>\n",
" <td>1.000000e-54</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>What is the GWAS p-value for the association b...</td>\n",
" <td>I'm sorry, but I don't have the information to...</td>\n",
" <td>11149</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" disease_name gene_name gwas_pvalue \\\n",
"0 childhood-onset asthma RORA 2.000000e-37 \n",
"1 skin benign neoplasm SHANK2 5.000000e-08 \n",
"2 hypertrophic cardiomyopathy AMBRA1 1.000000e-16 \n",
"3 lung adenocarcinoma CYP2A6 8.000000e-11 \n",
"4 idiopathic generalized epilepsy RYR2 3.000000e-09 \n",
".. ... ... ... \n",
"95 lung squamous cell carcinoma BRCA2 1.000000e-15 \n",
"96 systemic lupus erythematosus HLA-DRA 2.000000e-60 \n",
"97 type 2 diabetes mellitus UBE2E2 2.000000e-42 \n",
"98 allergic rhinitis HLA-DQA1 1.000000e-43 \n",
"99 systemic lupus erythematosus HLA-DQA1 1.000000e-54 \n",
"\n",
" question \\\n",
"0 What is the GWAS p-value for the association b... \n",
"1 What is the GWAS p-value for the association b... \n",
"2 Is hypertrophic cardiomyopathy associated with... \n",
"3 What is the GWAS p-value for the association b... \n",
"4 Is idiopathic generalized epilepsy associated ... \n",
".. ... \n",
"95 Is lung squamous cell carcinoma associated wit... \n",
"96 What is the GWAS p-value for the association b... \n",
"97 Is type 2 diabetes mellitus associated with UB... \n",
"98 What is the GWAS p-value for the association b... \n",
"99 What is the GWAS p-value for the association b... \n",
"\n",
" question_perturbed \\\n",
"0 What is the GWAS p-value for the association b... \n",
"1 What is the GWAS p-value for the association b... \n",
"2 Is hypertrophic cardiomyopathy associated with... \n",
"3 What is the GWAS p-value for the association b... \n",
"4 Is idiopathic generalized epilepsy associated ... \n",
".. ... \n",
"95 Is lung squamous cell carcinoma associated wit... \n",
"96 What is the GWAS p-value for the association b... \n",
"97 Is type 2 diabetes mellitus associated with ub... \n",
"98 What is the GWAS p-value for the association b... \n",
"99 What is the GWAS p-value for the association b... \n",
"\n",
" neo4j_rag_answer total_tokens_used \n",
"0 None 10993 \n",
"1 I'm sorry, but I don't know the answer. 11138 \n",
"2 Yes, hypertrophic cardiomyopathy is associated... 11468 \n",
"3 I'm sorry, but I don't have the information to... 11150 \n",
"4 No, idiopathic generalized epilepsy is not ass... 11129 \n",
".. ... ... \n",
"95 Yes, lung squamous cell carcinoma is associate... 11129 \n",
"96 I'm sorry, but I don't have the information to... 11152 \n",
"97 No, type 2 diabetes mellitus is not associated... 11139 \n",
"98 I'm sorry, but I don't know the answer. 11143 \n",
"99 I'm sorry, but I don't have the information to... 11149 \n",
"\n",
"[100 rows x 7 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "183b82b1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 171,
"id": "d514b0e6",
"metadata": {},
"outputs": [],
@@ -25,12 +25,12 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 172,
"id": "109057cd",
"metadata": {},
"outputs": [],
"source": [
"neo4j_rag = pd.read_csv('../data/results/neo4j_rag_output.csv')\n",
"neo4j_rag = pd.read_csv('../data/results/cypher_rag_output.csv')\n",
"kg_rag = pd.read_csv('../data/results/kg_rag_output.csv')\n",
"\n"
]
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 151,
"execution_count": 173,
"id": "12e415b1",
"metadata": {},
"outputs": [
@@ -78,7 +78,7 @@
"kg_sem = kg_rag['token_usage'].sem()\n",
"\n",
"\n",
"plt.figure(figsize=(3, 3))\n",
"fig = plt.figure(figsize=(3, 3))\n",
"\n",
"plt.bar(0, neo4j_avg, yerr=neo4j_sem, color='red', ecolor='black', capsize=5, label='Cypher-RAG')\n",
"\n",
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 174,
"id": "757f36d5",
"metadata": {},
"outputs": [
@@ -171,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 175,
"id": "0a433581",
"metadata": {},
"outputs": [
@@ -223,7 +223,7 @@
},
{
"cell_type": "code",
"execution_count": 152,
"execution_count": 176,
"id": "e6d8690d",
"metadata": {},
"outputs": [

View File

@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"id": "293b41f7",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import GraphCypherQAChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.graphs import Neo4jGraph\n",
"from dotenv import load_dotenv\n",
"import os\n",
"import openai\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ca51ba69",
"metadata": {},
"outputs": [],
"source": [
"\n",
"load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
"username = os.environ.get('NEO4J_USER')\n",
"password = os.environ.get('NEO4J_PSW')\n",
"url = os.environ.get('NEO4J_URI')\n",
"database = os.environ.get('NEO4J_DB')\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5db0b6ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 419 ms, sys: 72.2 ms, total: 491 ms\n",
"Wall time: 27.6 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"graph = Neo4jGraph(\n",
" url=url, \n",
" username=username, \n",
" password=password,\n",
" database = database\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5a6aa873",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING! engine is not default parameter.\n",
" engine was transferred to model_kwargs.\n",
" Please confirm that engine is what you intended.\n"
]
}
],
"source": [
"load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
"API_KEY = os.environ.get('API_KEY')\n",
"API_VERSION = os.environ.get('API_VERSION')\n",
"RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n",
"\n",
"openai.api_type = \"azure\"\n",
"openai.api_key = API_KEY\n",
"openai.api_base = RESOURCE_ENDPOINT\n",
"openai.api_version = API_VERSION\n",
"\n",
"chat_deployment_id = 'gpt-4-32k'\n",
"chat_model_id = chat_deployment_id\n",
"\n",
"temperature = 0\n",
"\n",
"chat_model = ChatOpenAI(openai_api_key=API_KEY, \n",
" engine=chat_deployment_id, \n",
" temperature=temperature)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "76437285",
"metadata": {},
"outputs": [],
"source": [
"chain = GraphCypherQAChain.from_llm(\n",
" chat_model, \n",
" graph=graph, \n",
" verbose=True, \n",
" validate_cypher=True,\n",
" return_intermediate_steps=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3fd3b9c5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks import get_openai_callback\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ed67a504",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
"Generated Cypher:\n",
"\u001b[32;1m\u001b[1;3mMATCH (c:Compound {name: 'levodopa'})-[:TREATS_CtD]->(d:Disease {name: 'Parkinson\\'s disease'}) RETURN c,d\u001b[0m\n",
"Full Context:\n",
"\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"I'm sorry, but I don't have the information to answer that question.\n"
]
}
],
"source": [
"question = 'What are the genes associated with multiple sclerosis?'\n",
"question = \"Is Parkinson's disease associated with levodopa?\"\n",
"# question=\"Which gene has stronger association with the disease 'liver benign neoplasm', is it PNPLA3 or HLA-B?\"\n",
"# question='What is the clinical phase of levodopa in treating parkinson disease?'\n",
"with get_openai_callback() as cb:\n",
" out = chain.run(query=question, return_final_only=False)\n",
" print(out)\n"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "d1673080",
"metadata": {},
"outputs": [],
"source": [
"question = 'What are the genes associated with multiple sclerosis?'\n",
"question=\"Which gene has stronger association with the disease 'liver benign neoplasm', is it PNPLA3 or HLA-B?\"\n",
"question='What is the clinical phase of levodopa treating parkinson disease?'\n",
"\n",
"class OutputCapturer:\n",
" def __init__(self):\n",
" self.outputs = []\n",
"\n",
" def write(self, output):\n",
" self.outputs.append(output)\n",
"\n",
"# Create an instance of OutputCapturer\n",
"output_capturer = OutputCapturer()\n",
"\n",
"# Redirect standard output to the output_capturer\n",
"import sys\n",
"original_stdout = sys.stdout\n",
"sys.stdout = output_capturer\n",
"\n",
"# Run the chain with your query\n",
"out = chain.run(query=question, return_final_only=False)\n",
"\n",
"# Restore original stdout\n",
"sys.stdout = original_stdout\n",
"\n",
"# Now `output_capturer.outputs` should contain all intermediate outputs\n"
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "4d84de0e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
]
}
],
"source": [
"print(output_capturer.outputs[8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3347c11f",
"metadata": {},
"outputs": [],
"source": [
"1. Correct context retrieval\n",
"2. Token utilization\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "ecffdba7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method Chain.run of GraphCypherQAChain(memory=None, callbacks=None, callback_manager=None, verbose=True, tags=None, metadata=None, graph=<langchain.graphs.neo4j_graph.Neo4jGraph object at 0x7fd47aab9ed0>, cypher_generation_chain=LLMChain(memory=None, callbacks=None, callback_manager=None, verbose=False, tags=None, metadata=None, prompt=PromptTemplate(input_variables=['schema', 'question'], output_parser=None, partial_variables={}, template='Task:Generate Cypher statement to query a graph database.\\nInstructions:\\nUse only the provided relationship types and properties in the schema.\\nDo not use any other relationship types or properties that are not provided.\\nSchema:\\n{schema}\\nNote: Do not include any explanations or apologies in your responses.\\nDo not respond to any questions that might ask anything else than for you to construct a Cypher statement.\\nDo not include any text except the generated Cypher statement.\\n\\nThe question is:\\n{question}', template_format='f-string', validate_template=True), llm=ChatOpenAI(cache=None, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, model_name='gpt-3.5-turbo', temperature=0.0, model_kwargs={'engine': 'gpt-4-32k'}, openai_api_key='N2ZiNDk0ZjhkNWNiNDhhZjhlMjNhNzY0YjNhYWRkZjI6M0FBZjBEODY2ZkIxNGQxZmEwRDc2NjRiQjQzMzFBOTI=', openai_api_base='', openai_organization='', openai_proxy='', request_timeout=None, max_retries=6, streaming=False, n=1, max_tokens=None, tiktoken_model_name=None), output_key='text', output_parser=StrOutputParser(), return_final_only=True, llm_kwargs={}), qa_chain=LLMChain(memory=None, callbacks=None, callback_manager=None, verbose=False, tags=None, metadata=None, prompt=PromptTemplate(input_variables=['context', 'question'], output_parser=None, partial_variables={}, template=\"You are an assistant that helps to form nice and human understandable answers.\\nThe information part contains the provided information that you must use to construct an answer.\\nThe provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.\\nMake the answer sound as a response to the question. Do not mention that you based the result on the given information.\\nIf the provided information is empty, say that you don't know the answer.\\nInformation:\\n{context}\\n\\nQuestion: {question}\\nHelpful Answer:\", template_format='f-string', validate_template=True), llm=ChatOpenAI(cache=None, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, model_name='gpt-3.5-turbo', temperature=0.0, model_kwargs={'engine': 'gpt-4-32k'}, openai_api_key='N2ZiNDk0ZjhkNWNiNDhhZjhlMjNhNzY0YjNhYWRkZjI6M0FBZjBEODY2ZkIxNGQxZmEwRDc2NjRiQjQzMzFBOTI=', openai_api_base='', openai_organization='', openai_proxy='', request_timeout=None, max_retries=6, streaming=False, n=1, max_tokens=None, tiktoken_model_name=None), output_key='text', output_parser=StrOutputParser(), return_final_only=True, llm_kwargs={}), input_key='query', output_key='result', top_k=10, return_intermediate_steps=True, return_direct=False)>"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "db7ed8c9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Task:Generate Cypher statement to query a graph database.\n",
"Instructions:\n",
"Use only the provided relationship types and properties in the schema.\n",
"Do not use any other relationship types or properties that are not provided.\n",
"Schema:\n",
"{schema}\n",
"Note: Do not include any explanations or apologies in your responses.\n",
"Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.\n",
"Do not include any text except the generated Cypher statement.\n",
"\n",
"The question is:\n",
"{question}\n"
]
}
],
"source": [
"print(chain.cypher_generation_chain.prompt.template)\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "2c98eaf6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are an assistant that helps to form nice and human understandable answers.\n",
"The information part contains the provided information that you must use to construct an answer.\n",
"The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.\n",
"Make the answer sound as a response to the question. Do not mention that you based the result on the given information.\n",
"If the provided information is empty, say that you don't know the answer.\n",
"Information:\n",
"{context}\n",
"\n",
"Question: {question}\n",
"Helpful Answer:\n"
]
}
],
"source": [
"print(chain.qa_chain.prompt.template)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "46905634",
"metadata": {},
"outputs": [],
"source": [
"schema = graph.schema\n",
"question\n",
"cypher_template = chain.cypher_generation_chain.prompt.template\n",
"cypher_template = cypher_template.format(schema=schema, question=question)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a392f860",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}