mirror of
https://github.com/BaranziniLab/KG_RAG.git
synced 2024-06-08 14:12:54 +03:00
340 lines
12 KiB
Plaintext
340 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "293b41f7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.chains import GraphCypherQAChain\n",
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"from langchain.graphs import Neo4jGraph\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"import os\n",
|
|
"import openai\n"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "ca51ba69",
 "metadata": {},
 "outputs": [],
 "source": [
  "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
  "username = os.environ.get('NEO4J_USER')\n",
  "password = os.environ.get('NEO4J_PSW')\n",
  "url = os.environ.get('NEO4J_URI')\n",
  "database = os.environ.get('NEO4J_DB')  # optional: None is passed through as before\n",
  "\n",
  "# Fail fast with a readable message instead of handing None credentials to\n",
  "# Neo4jGraph in the next cell.\n",
  "_required_neo4j_settings = {'NEO4J_USER': username, 'NEO4J_PSW': password, 'NEO4J_URI': url}\n",
  "_missing_neo4j_settings = [name for name, value in _required_neo4j_settings.items() if not value]\n",
  "if _missing_neo4j_settings:\n",
  "    raise EnvironmentError(\n",
  "        'Missing Neo4j settings (expected in ~/.spoke_neo4j_config.env or the environment): '\n",
  "        + ', '.join(_missing_neo4j_settings)\n",
  "    )\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 14,
 "id": "5db0b6ca",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "CPU times: user 419 ms, sys: 72.2 ms, total: 491 ms\n",
    "Wall time: 27.6 s\n"
   ]
  }
 ],
 "source": [
  "%%time\n",
  "\n",
  "# Build the Neo4jGraph wrapper from the settings read in the previous cell.\n",
  "neo4j_settings = dict(url=url, username=username, password=password, database=database)\n",
  "graph = Neo4jGraph(**neo4j_settings)\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "5a6aa873",
 "metadata": {},
 "outputs": [],
 "source": [
  "load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
  "API_KEY = os.environ.get('API_KEY')\n",
  "API_VERSION = os.environ.get('API_VERSION')\n",
  "RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n",
  "\n",
  "# Azure OpenAI is configured globally on the `openai` module. ChatOpenAI below is\n",
  "# only given the key, so endpoint/type/version come from these module settings.\n",
  "openai.api_type = \"azure\"\n",
  "openai.api_key = API_KEY\n",
  "openai.api_base = RESOURCE_ENDPOINT\n",
  "openai.api_version = API_VERSION\n",
  "\n",
  "chat_deployment_id = 'gpt-4-32k'  # Azure deployment name\n",
  "temperature = 0\n",
  "\n",
  "# `engine` is not a ChatOpenAI field: passing it through `model_kwargs` forwards it\n",
  "# to the completion call exactly as before, without the \"engine was transferred\n",
  "# to model_kwargs\" warning.\n",
  "chat_model = ChatOpenAI(openai_api_key=API_KEY,\n",
  "                        model_kwargs={'engine': chat_deployment_id},\n",
  "                        temperature=temperature)\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 16,
 "id": "76437285",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Cypher QA chain over the SPOKE graph; verbose=True echoes the generated Cypher\n",
  "# and the retrieved context while the chain runs.\n",
  "chain_options = dict(\n",
  "    verbose=True,\n",
  "    validate_cypher=True,\n",
  "    return_intermediate_steps=True,\n",
  ")\n",
  "chain = GraphCypherQAChain.from_llm(chat_model, graph=graph, **chain_options)"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "3fd3b9c5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.callbacks import get_openai_callback\n"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "ed67a504",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Other questions tried with this chain:\n",
  "#   'What are the genes associated with multiple sclerosis?'\n",
  "#   \"Which gene has stronger association with the disease 'liver benign neoplasm', is it PNPLA3 or HLA-B?\"\n",
  "#   'What is the clinical phase of levodopa in treating parkinson disease?'\n",
  "question = \"Is Parkinson's disease associated with levodopa?\"\n",
  "\n",
  "# `chain.run` returns only the final answer string, and its extra keyword arguments\n",
  "# are treated as chain inputs, so `return_final_only=False` had no effect. Calling\n",
  "# the chain directly returns the full output dict, including the steps requested\n",
  "# with `return_intermediate_steps=True`.\n",
  "with get_openai_callback() as cb:\n",
  "    response = chain({\"query\": question})\n",
  "\n",
  "out = response[\"result\"]\n",
  "print(out)\n",
  "print(response[\"intermediate_steps\"])\n",
  "# Token usage accumulated by every LLM call made inside the `with` block.\n",
  "print(cb)\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "d1673080",
 "metadata": {},
 "outputs": [],
 "source": [
  "import contextlib\n",
  "\n",
  "# Other questions tried here:\n",
  "#   'What are the genes associated with multiple sclerosis?'\n",
  "#   \"Which gene has stronger association with the disease 'liver benign neoplasm', is it PNPLA3 or HLA-B?\"\n",
  "question = 'What is the clinical phase of levodopa treating parkinson disease?'\n",
  "\n",
  "\n",
  "class OutputCapturer:\n",
  "    \"\"\"Minimal stdout replacement that keeps every chunk written to it in `outputs`.\"\"\"\n",
  "\n",
  "    def __init__(self):\n",
  "        self.outputs = []\n",
  "\n",
  "    def write(self, output):\n",
  "        self.outputs.append(output)\n",
  "        return len(output)  # text streams report the number of characters written\n",
  "\n",
  "    def flush(self):\n",
  "        # Nothing is buffered; defined so code that flushes stdout does not fail.\n",
  "        pass\n",
  "\n",
  "\n",
  "output_capturer = OutputCapturer()\n",
  "\n",
  "# redirect_stdout restores the real sys.stdout even if the chain raises, so a\n",
  "# failed run cannot leave the kernel's output silently swallowed.\n",
  "with contextlib.redirect_stdout(output_capturer):\n",
  "    out = chain.run(query=question)\n",
  "\n",
  "# `output_capturer.outputs` now holds every chunk the verbose chain printed.\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "4d84de0e",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Find the retrieved context by its label instead of a hard-coded position: the\n",
  "# verbose chain prints \"Full Context:\" followed by the (colored) context chunk.\n",
  "captured = output_capturer.outputs\n",
  "label_positions = [i for i, chunk in enumerate(captured) if chunk.strip() == 'Full Context:']\n",
  "if not label_positions:\n",
  "    raise ValueError('No \"Full Context:\" section found in the captured chain output')\n",
  "retrieved_context = next(chunk for chunk in captured[label_positions[0] + 1:] if chunk.strip())\n",
  "print(retrieved_context)"
 ]
},
|
|
{
 "cell_type": "markdown",
 "id": "3347c11f",
 "metadata": {},
 "source": [
  "Evaluation criteria:\n",
  "\n",
  "1. Correct context retrieval\n",
  "2. Token utilization\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "ecffdba7",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Show the chain settings without echoing the chain's repr: that repr embeds the\n",
  "# LLM objects, including `openai_api_key`, and would be saved into the notebook.\n",
  "{\n",
  "    'input_key': chain.input_key,\n",
  "    'output_key': chain.output_key,\n",
  "    'top_k': chain.top_k,\n",
  "    'return_intermediate_steps': chain.return_intermediate_steps,\n",
  "}"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 29,
 "id": "db7ed8c9",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Task:Generate Cypher statement to query a graph database.\n",
    "Instructions:\n",
    "Use only the provided relationship types and properties in the schema.\n",
    "Do not use any other relationship types or properties that are not provided.\n",
    "Schema:\n",
    "{schema}\n",
    "Note: Do not include any explanations or apologies in your responses.\n",
    "Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.\n",
    "Do not include any text except the generated Cypher statement.\n",
    "\n",
    "The question is:\n",
    "{question}\n"
   ]
  }
 ],
 "source": [
  "# Prompt that turns the question into Cypher; {schema} and {question} are its inputs.\n",
  "cypher_generation_prompt = chain.cypher_generation_chain.prompt\n",
  "print(cypher_generation_prompt.template)\n"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 42,
 "id": "2c98eaf6",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "You are an assistant that helps to form nice and human understandable answers.\n",
    "The information part contains the provided information that you must use to construct an answer.\n",
    "The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.\n",
    "Make the answer sound as a response to the question. Do not mention that you based the result on the given information.\n",
    "If the provided information is empty, say that you don't know the answer.\n",
    "Information:\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "Helpful Answer:\n"
   ]
  }
 ],
 "source": [
  "# Prompt that turns the retrieved context into the final answer.\n",
  "qa_answer_prompt = chain.qa_chain.prompt\n",
  "print(qa_answer_prompt.template)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "46905634",
 "metadata": {},
 "outputs": [],
 "source": [
  "schema = graph.schema\n",
  "# Fill the Cypher-generation prompt with the live schema and the current question\n",
  "# to inspect the full prompt text used for Cypher generation.\n",
  "cypher_template = chain.cypher_generation_chain.prompt.template.format(schema=schema, question=question)\n"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a392f860",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|