mirror of
https://github.com/BaranziniLab/KG_RAG.git
synced 2024-06-08 14:12:54 +03:00
196 lines
7.4 KiB
Plaintext
196 lines
7.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "b33a915d-cc1d-4102-a2ee-159c02e6c579",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"# Re-running this cell must not climb another directory level, so only step up to the\n",
|
|
"# parent when the `kg_rag` package imported below is not already in the working directory.\n",
|
|
"if not os.path.isdir('kg_rag'):\n",
|
|
"    os.chdir('..')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "57c0a1b8-e339-4f6b-941e-7af7b902de7c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain import PromptTemplate, LLMChain\n",
|
|
"from kg_rag.utility import *\n",
|
|
"from tqdm import tqdm\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "2672548d-7d25-4f3c-94d1-d19206049076",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"QUESTION_PATH = config_data[\"MCQ_PATH\"]\n",
|
|
"SYSTEM_PROMPT = system_prompts[\"MCQ_QUESTION\"]\n",
|
|
"CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
|
|
"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
|
|
"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
|
|
"VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
|
|
"NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
|
|
"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
|
|
"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
|
|
"SAVE_PATH = config_data[\"SAVE_RESULTS_PATH\"]\n",
|
|
"\n",
|
|
"MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B'\n",
|
|
"BRANCH_NAME = 'main'\n",
|
|
"CACHE_DIR = config_data[\"LLM_CACHE_DIR\"]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "c753b053-be44-4ddb-8d55-3bf434428954",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"INSTRUCTION = \"Context:\\n\\n{context} \\n\\nQuestion: {question}\"\n",
|
|
"\n",
|
|
"vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
|
|
"embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
|
|
"node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n",
|
|
"edge_evidence = False\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "f18c9efb-556c-4b37-8b00-e06a73a19f86",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00, 6.66s/it]\n",
|
|
"/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
|
|
" warnings.warn(\n",
|
|
"/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
|
|
" warnings.warn(\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) \n",
|
|
"template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)\n",
|
|
"prompt = PromptTemplate(template=template, input_variables=[\"context\", \"question\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "0370d703-4e18-4c78-9e9a-2030b498253e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"llm_chain = LLMChain(prompt=prompt, llm=llm) \n",
|
|
"question_df = pd.read_csv(QUESTION_PATH) \n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "275f4171-3be7-46ca-bf16-18160ce72f3b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"\"Out of the given list, which Gene is associated with psoriasis and Takayasu's arteritis. Given list is: SHTN1, HLA-B, SLC14A2, BTBD9, DTNB\""
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"question_df.iloc[0].text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "cc5a65fb-6bd3-4948-84e5-f404af83d3f7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0it [00:00, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4135 > 2048). Running this sequence through the model will result in indexing errors\n",
|
|
"This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.\n",
|
|
"0it [04:19, ?it/s]\n",
|
|
"\n",
|
|
"KeyboardInterrupt\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%time\n",
|
|
"\n",
|
|
"# Sample into a separate frame so re-running this cell never re-samples an already\n",
|
|
"# sampled `question_df`; cap the size so a short question file does not raise.\n",
|
|
"question_sample_df = question_df.sample(min(50, len(question_df)), random_state=40)\n",
|
|
"\n",
|
|
"answer_list = []\n",
|
|
"# `iterrows()` has no length, so pass `total` to let tqdm report real progress.\n",
|
|
"for index, row in tqdm(question_sample_df.iterrows(), total=len(question_sample_df)):\n",
|
|
"    question = row[\"text\"]\n",
|
|
"    context = retrieve_context(\n",
|
|
"        question,\n",
|
|
"        vectorstore,\n",
|
|
"        embedding_function_for_context_retrieval,\n",
|
|
"        node_context_df,\n",
|
|
"        CONTEXT_VOLUME,\n",
|
|
"        QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD,\n",
|
|
"        QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY,\n",
|
|
"        edge_evidence,\n",
|
|
"    )\n",
|
|
"    output = llm_chain.run(context=context, question=question)\n",
|
|
"    print(output)\n",
|
|
"    # No `input()` pause here: it stalls unattended runs until a key is pressed.\n",
|
|
"    answer_list.append((question, row[\"correct_node\"], output))\n",
|
|
"\n",
|
|
"answer_df = pd.DataFrame(answer_list, columns=[\"question\", \"correct_answer\", \"llm_answer\"])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "94eb325d-17d4-4013-907d-7a38dabaea56",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# `save_name` is not assigned anywhere in this notebook, so build the output file name\n",
|
|
"# here from the model identifier, e.g. 'org/My-Model-7B' -> 'My_Model_7B_kg_rag_based_mcq_response.csv'.\n",
|
|
"# NOTE(review): the file-name suffix is chosen here; confirm it matches the project's convention.\n",
|
|
"save_name = \"_\".join(MODEL_NAME.split(\"/\")[-1].split(\"-\")) + \"_kg_rag_based_mcq_response.csv\"\n",
|
|
"answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|