{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa1bc1aa-68b1-41bb-94f0-033c34d76b13",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32d27545-5454-4328-b62f-91a86c905710",
   "metadata": {},
   "outputs": [],
   "source": [
    "BIOMEDGPT_MODEL_CARD = 'PharMolix/BioMedGPT-LM-7B'\n",
    "BRANCH_NAME = 'main'\n",
    "CACHE_DIR = '/data/somank/llm_data/llm_models/huggingface'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e875c7d1-22de-4787-b75f-71975d11ff92",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00, 6.56s/it]\n",
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      " warnings.warn(\n",
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      " warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 18.9 s, sys: 26.3 s, total: 45.3 s\n",
      "Wall time: 26 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Load the BioMedGPT tokenizer and model from the local Hugging Face cache,\n",
    "# keeping weights in float16 and letting device_map='auto' place them on the\n",
    "# available devices.\n",
    "tokenizer = AutoTokenizer.from_pretrained(BIOMEDGPT_MODEL_CARD,\n",
    "                                          revision=BRANCH_NAME,\n",
    "                                          cache_dir=CACHE_DIR)\n",
    "model = AutoModelForCausalLM.from_pretrained(BIOMEDGPT_MODEL_CARD,\n",
    "                                             device_map='auto',\n",
    "                                             torch_dtype=torch.float16,\n",
    "                                             revision=BRANCH_NAME,\n",
    "                                             cache_dir=CACHE_DIR)"
   ]
  },
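  {
   "cell_type": "markdown",
   "id": "helper-sketch-markdown",
   "metadata": {},
   "source": [
    "A minimal sketch, not part of the original notebook: the tokenize/generate/decode steps can be wrapped in one helper that reuses the `tokenizer` and `model` loaded above. The helper name `generate_response` and its defaults are assumptions; `do_sample=False` is used so the `temperature`/`top_p` flags warned about during loading are simply ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helper-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_response(prompt, max_new_tokens=64):\n",
    "    # Hypothetical helper (assumption, not from the original notebook):\n",
    "    # tokenize the prompt, run greedy decoding, and return only the newly\n",
    "    # generated text (the prompt tokens are sliced off before decoding).\n",
    "    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
    "    with torch.no_grad():\n",
    "        output = model.generate(**inputs,\n",
    "                                max_new_tokens=max_new_tokens,\n",
    "                                do_sample=False)\n",
    "    new_tokens = output[0][inputs['input_ids'].shape[1]:]\n",
    "    return tokenizer.decode(new_tokens, skip_special_tokens=True)\n"
   ]
  },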
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0a28f1ce-5cc5-4a17-84d7-b0dda29815a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two test prompts: a True/False disease-gene question and a\n",
    "# multiple-choice gene association question.\n",
    "question = '''\n",
    "For the following Question, answer if it is either True or False.\n",
    "Question:alpha-Mannosidosis associates Gene MAN2B1\n",
    "Answer:\n",
    "'''\n",
    "text = [\"Out of the given list, which Gene is associated with psoriasis and Takayasu's arteritis. Given list is: SHTN1, HLA-B, SLC14A2, BTBD9, DTNB\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f59eeb37-57dd-42ae-b9ff-f8442eb613a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P1, TNFAIP3, TNIP1, TNIP3, TNFAIP2, TNFAIP6, TNFAIP7, TNFAIP8, TNFAIP9, TNFAIP10\n"
     ]
    }
   ],
   "source": [
    "# Tokenize the multiple-choice prompt and stream tokens as they are generated.\n",
    "# `inputs` avoids shadowing the built-in `input`; temperature=0.01 with\n",
    "# do_sample=True makes the sampling nearly greedy.\n",
    "inputs = tokenizer(text,\n",
    "                   truncation=True,\n",
    "                   return_tensors=\"pt\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
    "    output = model.generate(**inputs,\n",
    "                            streamer=streamer,\n",
    "                            max_new_tokens=64,\n",
    "                            temperature=0.01,\n",
    "                            do_sample=True)\n",
    "    output_text = tokenizer.decode(output[0],\n",
    "                                   skip_special_tokens=True)\n"
   ]
  },
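  {
   "cell_type": "markdown",
   "id": "question-prompt-markdown",
   "metadata": {},
   "source": [
    "The True/False prompt stored in `question` above is never run in the recorded cells. As an illustrative sketch, it could be passed through the hypothetical `generate_response` helper defined earlier; the output is not recorded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "question-prompt-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: run the unused True/False prompt through the\n",
    "# hypothetical helper sketched above.\n",
    "print(generate_response(question, max_new_tokens=8))\n"
   ]
  },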
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e94dd0-4861-402a-9a3f-57874efcde31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}