more work

This commit is contained in:
Kyle Corbitt
2023-08-24 23:49:44 +00:00
parent 14eae45d18
commit 40638a7848
9 changed files with 2708 additions and 463 deletions

1
examples/.gitignore vendored
View File

@@ -1,3 +1,4 @@
axolotl/
models/
data/
wandb/

View File

@@ -0,0 +1,473 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 21:25:06\n",
"Current Time: 2023-08-24 21:25:36\n"
]
}
],
"source": [
"import time\n",
"\n",
"while True:\n",
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
" print(f\"Current Time: {current_time}\")\n",
" time.sleep(30)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 30,000 to 40,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 40,000 to 50,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Cost to Classify One Recipe</th>\n",
" <th>Cost to Classify Entire Dataset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
"finetuned_hourly_cost = 1.14\n",
"\n",
"finetuned_total_hours = 16.54\n",
"\n",
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
"\n",
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
"avg_input_tokens = 276\n",
"avg_output_tokens = 42\n",
"\n",
"# Token pricing from https://openai.com/pricing\n",
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
"\n",
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
"\n",
"gpt_35_finetuned_avg_cost = (\n",
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"# Multiply the number of recipes\n",
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
"\n",
"# Let's put this in a dataframe for easier comparison.\n",
"\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
" finetuned_avg_cost,\n",
" gpt_35_avg_cost,\n",
" gpt_35_finetuned_avg_cost,\n",
" gpt_4_avg_cost,\n",
" ],\n",
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"\n",
"costs\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,123 @@
# %% [markdown]
# I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?
#
# I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈.
# %%
# %%
from datasets import load_dataset
all_recipes = load_dataset("corbt/all-recipes")["train"]["input"]
print(f"Number of recipes: {len(all_recipes):,}")
# %%
from vllm import LLM, SamplingParams
llm = LLM(model="./models/run1/merged", max_num_batched_tokens=4096)
sampling_params = SamplingParams(
# 120 should be fine for the work we're doing here.
max_tokens=120,
# This is a deterministic task so temperature=0 is best.
temperature=0,
)
# %%
import os
import time
import json
BATCH_SIZE = 10000
start_time = time.time()
print(f"Start time: {start_time}")
for i in range(0, len(all_recipes), BATCH_SIZE):
# File name for the current batch
file_name = f"./data/benchmark_batch_{int(i/BATCH_SIZE)}.txt"
# Check if the file already exists; if so, skip to the next batch
if os.path.exists(file_name):
print(f"File {file_name} exists, skipping recipes {i:,} to {i+BATCH_SIZE:,}...")
continue
print(f"Processing recipes {i:,} to {i+BATCH_SIZE:,}...")
outputs = llm.generate(
all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params
)
outputs = [o.outputs[0].text for o in outputs]
# Write the generated outputs to the file as a JSON list
json.dump(outputs, open(file_name, "w"))
end_time = time.time()
print(f"End time: {end_time}")
print(f"Total hours: {((end_time - start_time) / 3600):.2f}")
# %% [markdown]
# Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data.
# %%
import pandas as pd
# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.
finetuned_hourly_cost = 1.14
finetuned_total_hours = 17
finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)
# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them
avg_input_tokens = 276
avg_output_tokens = 42
# Token pricing from https://openai.com/pricing
gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000
gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000
gpt_35_finetuned_avg_cost = (
avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000
)
# Multiply the number of recipes
# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost
# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost
# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost
# Let's put this in a dataframe for easier comparison.
costs = pd.DataFrame(
{
"Model": [
"Llama 2 7B (finetuned)",
"GPT-3.5",
"GPT-3.5 (finetuned)",
"GPT-4",
],
"Cost to Classify One Recipe": [
finetuned_avg_cost,
gpt_35_avg_cost,
gpt_35_finetuned_avg_cost,
gpt_4_avg_cost,
],
}
)
costs["Cost to Classify Entire Dataset"] = (
costs["Cost to Classify One Recipe"] * len(all_recipes)
).map(lambda x: f"{x:,.2f}")
costs
# %% [markdown]
# ...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂
# %%

View File

@@ -0,0 +1,913 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 22:02:07\n",
"Current Time: 2023-08-24 22:02:17\n",
"Current Time: 2023-08-24 22:02:27\n",
"Current Time: 2023-08-24 22:02:37\n",
"Current Time: 2023-08-24 22:02:47\n",
"Current Time: 2023-08-24 22:02:57\n",
"Current Time: 2023-08-24 22:03:08\n",
"Current Time: 2023-08-24 22:03:18\n",
"Current Time: 2023-08-24 22:03:28\n",
"Current Time: 2023-08-24 22:03:38\n",
"Current Time: 2023-08-24 22:03:48\n",
"Current Time: 2023-08-24 22:03:58\n",
"Current Time: 2023-08-24 22:04:08\n",
"Current Time: 2023-08-24 22:04:18\n",
"Current Time: 2023-08-24 22:04:28\n",
"Current Time: 2023-08-24 22:04:38\n",
"Current Time: 2023-08-24 22:04:48\n",
"Current Time: 2023-08-24 22:04:58\n",
"Current Time: 2023-08-24 22:05:08\n",
"Current Time: 2023-08-24 22:05:18\n",
"Current Time: 2023-08-24 22:05:28\n",
"Current Time: 2023-08-24 22:05:38\n",
"Current Time: 2023-08-24 22:05:48\n",
"Current Time: 2023-08-24 22:05:58\n",
"Current Time: 2023-08-24 22:06:08\n",
"Current Time: 2023-08-24 22:06:18\n",
"Current Time: 2023-08-24 22:06:28\n",
"Current Time: 2023-08-24 22:06:38\n",
"Current Time: 2023-08-24 22:06:48\n",
"Current Time: 2023-08-24 22:06:58\n",
"Current Time: 2023-08-24 22:07:08\n",
"Current Time: 2023-08-24 22:07:18\n",
"Current Time: 2023-08-24 22:07:28\n",
"Current Time: 2023-08-24 22:07:38\n",
"Current Time: 2023-08-24 22:07:48\n",
"Current Time: 2023-08-24 22:07:58\n",
"Current Time: 2023-08-24 22:08:08\n",
"Current Time: 2023-08-24 22:08:18\n",
"Current Time: 2023-08-24 22:08:28\n",
"Current Time: 2023-08-24 22:08:38\n",
"Current Time: 2023-08-24 22:08:48\n",
"Current Time: 2023-08-24 22:08:58\n",
"Current Time: 2023-08-24 22:09:08\n",
"Current Time: 2023-08-24 22:09:18\n",
"Current Time: 2023-08-24 22:09:28\n",
"Current Time: 2023-08-24 22:09:38\n",
"Current Time: 2023-08-24 22:09:48\n",
"Current Time: 2023-08-24 22:09:58\n",
"Current Time: 2023-08-24 22:10:08\n",
"Current Time: 2023-08-24 22:10:18\n",
"Current Time: 2023-08-24 22:10:28\n",
"Current Time: 2023-08-24 22:10:38\n",
"Current Time: 2023-08-24 22:10:48\n",
"Current Time: 2023-08-24 22:10:58\n",
"Current Time: 2023-08-24 22:11:08\n",
"Current Time: 2023-08-24 22:11:18\n",
"Current Time: 2023-08-24 22:11:28\n",
"Current Time: 2023-08-24 22:11:38\n",
"Current Time: 2023-08-24 22:11:48\n",
"Current Time: 2023-08-24 22:11:58\n",
"Current Time: 2023-08-24 22:12:08\n",
"Current Time: 2023-08-24 22:12:18\n",
"Current Time: 2023-08-24 22:12:28\n",
"Current Time: 2023-08-24 22:12:38\n",
"Current Time: 2023-08-24 22:12:48\n",
"Current Time: 2023-08-24 22:12:58\n",
"Current Time: 2023-08-24 22:13:08\n",
"Current Time: 2023-08-24 22:13:18\n",
"Current Time: 2023-08-24 22:13:28\n",
"Current Time: 2023-08-24 22:13:38\n",
"Current Time: 2023-08-24 22:13:48\n",
"Current Time: 2023-08-24 22:13:58\n",
"Current Time: 2023-08-24 22:14:08\n",
"Current Time: 2023-08-24 22:14:18\n",
"Current Time: 2023-08-24 22:14:28\n",
"Current Time: 2023-08-24 22:14:38\n",
"Current Time: 2023-08-24 22:14:48\n",
"Current Time: 2023-08-24 22:14:58\n",
"Current Time: 2023-08-24 22:15:08\n",
"Current Time: 2023-08-24 22:15:18\n",
"Current Time: 2023-08-24 22:15:28\n",
"Current Time: 2023-08-24 22:15:38\n",
"Current Time: 2023-08-24 22:15:48\n",
"Current Time: 2023-08-24 22:15:58\n",
"Current Time: 2023-08-24 22:16:08\n",
"Current Time: 2023-08-24 22:16:18\n",
"Current Time: 2023-08-24 22:16:28\n",
"Current Time: 2023-08-24 22:16:38\n",
"Current Time: 2023-08-24 22:16:48\n",
"Current Time: 2023-08-24 22:16:58\n",
"Current Time: 2023-08-24 22:17:08\n",
"Current Time: 2023-08-24 22:17:18\n",
"Current Time: 2023-08-24 22:17:28\n",
"Current Time: 2023-08-24 22:17:38\n",
"Current Time: 2023-08-24 22:17:48\n",
"Current Time: 2023-08-24 22:17:58\n",
"Current Time: 2023-08-24 22:18:08\n",
"Current Time: 2023-08-24 22:18:18\n",
"Current Time: 2023-08-24 22:18:28\n",
"Current Time: 2023-08-24 22:18:38\n",
"Current Time: 2023-08-24 22:18:48\n",
"Current Time: 2023-08-24 22:18:58\n",
"Current Time: 2023-08-24 22:19:08\n",
"Current Time: 2023-08-24 22:19:18\n",
"Current Time: 2023-08-24 22:19:28\n",
"Current Time: 2023-08-24 22:19:38\n",
"Current Time: 2023-08-24 22:19:48\n",
"Current Time: 2023-08-24 22:19:58\n",
"Current Time: 2023-08-24 22:20:08\n",
"Current Time: 2023-08-24 22:20:18\n",
"Current Time: 2023-08-24 22:20:28\n",
"Current Time: 2023-08-24 22:20:39\n",
"Current Time: 2023-08-24 22:20:49\n",
"Current Time: 2023-08-24 22:20:59\n",
"Current Time: 2023-08-24 22:21:09\n",
"Current Time: 2023-08-24 22:21:19\n",
"Current Time: 2023-08-24 22:21:29\n",
"Current Time: 2023-08-24 22:21:39\n",
"Current Time: 2023-08-24 22:21:49\n",
"Current Time: 2023-08-24 22:21:59\n",
"Current Time: 2023-08-24 22:22:09\n",
"Current Time: 2023-08-24 22:22:19\n",
"Current Time: 2023-08-24 22:22:29\n",
"Current Time: 2023-08-24 22:22:39\n",
"Current Time: 2023-08-24 22:22:49\n",
"Current Time: 2023-08-24 22:22:59\n",
"Current Time: 2023-08-24 22:23:09\n",
"Current Time: 2023-08-24 22:23:19\n",
"Current Time: 2023-08-24 22:23:29\n",
"Current Time: 2023-08-24 22:23:39\n",
"Current Time: 2023-08-24 22:23:49\n",
"Current Time: 2023-08-24 22:23:59\n",
"Current Time: 2023-08-24 22:24:09\n",
"Current Time: 2023-08-24 22:24:19\n",
"Current Time: 2023-08-24 22:24:29\n",
"Current Time: 2023-08-24 22:24:39\n",
"Current Time: 2023-08-24 22:24:49\n",
"Current Time: 2023-08-24 22:24:59\n",
"Current Time: 2023-08-24 22:25:09\n",
"Current Time: 2023-08-24 22:25:19\n",
"Current Time: 2023-08-24 22:25:29\n",
"Current Time: 2023-08-24 22:25:39\n",
"Current Time: 2023-08-24 22:25:49\n",
"Current Time: 2023-08-24 22:25:59\n",
"Current Time: 2023-08-24 22:26:09\n",
"Current Time: 2023-08-24 22:26:19\n",
"Current Time: 2023-08-24 22:26:29\n",
"Current Time: 2023-08-24 22:26:39\n",
"Current Time: 2023-08-24 22:26:49\n",
"Current Time: 2023-08-24 22:26:59\n",
"Current Time: 2023-08-24 22:27:09\n",
"Current Time: 2023-08-24 22:27:19\n",
"Current Time: 2023-08-24 22:27:29\n",
"Current Time: 2023-08-24 22:27:39\n",
"Current Time: 2023-08-24 22:27:49\n",
"Current Time: 2023-08-24 22:27:59\n",
"Current Time: 2023-08-24 22:28:09\n",
"Current Time: 2023-08-24 22:28:19\n",
"Current Time: 2023-08-24 22:28:29\n",
"Current Time: 2023-08-24 22:28:39\n",
"Current Time: 2023-08-24 22:28:49\n",
"Current Time: 2023-08-24 22:28:59\n",
"Current Time: 2023-08-24 22:29:09\n",
"Current Time: 2023-08-24 22:29:19\n",
"Current Time: 2023-08-24 22:29:29\n",
"Current Time: 2023-08-24 22:29:39\n",
"Current Time: 2023-08-24 22:29:49\n",
"Current Time: 2023-08-24 22:29:59\n",
"Current Time: 2023-08-24 22:30:09\n",
"Current Time: 2023-08-24 22:30:19\n",
"Current Time: 2023-08-24 22:30:29\n",
"Current Time: 2023-08-24 22:30:39\n",
"Current Time: 2023-08-24 22:30:49\n",
"Current Time: 2023-08-24 22:30:59\n",
"Current Time: 2023-08-24 22:31:09\n",
"Current Time: 2023-08-24 22:31:19\n",
"Current Time: 2023-08-24 22:31:29\n",
"Current Time: 2023-08-24 22:31:39\n",
"Current Time: 2023-08-24 22:31:49\n",
"Current Time: 2023-08-24 22:31:59\n",
"Current Time: 2023-08-24 22:32:09\n",
"Current Time: 2023-08-24 22:32:19\n",
"Current Time: 2023-08-24 22:32:29\n",
"Current Time: 2023-08-24 22:32:39\n",
"Current Time: 2023-08-24 22:32:49\n",
"Current Time: 2023-08-24 22:32:59\n",
"Current Time: 2023-08-24 22:33:09\n",
"Current Time: 2023-08-24 22:33:19\n",
"Current Time: 2023-08-24 22:33:29\n",
"Current Time: 2023-08-24 22:33:39\n",
"Current Time: 2023-08-24 22:33:49\n",
"Current Time: 2023-08-24 22:33:59\n",
"Current Time: 2023-08-24 22:34:09\n",
"Current Time: 2023-08-24 22:34:19\n",
"Current Time: 2023-08-24 22:34:29\n",
"Current Time: 2023-08-24 22:34:39\n",
"Current Time: 2023-08-24 22:34:49\n",
"Current Time: 2023-08-24 22:34:59\n",
"Current Time: 2023-08-24 22:35:09\n",
"Current Time: 2023-08-24 22:35:19\n",
"Current Time: 2023-08-24 22:35:29\n",
"Current Time: 2023-08-24 22:35:39\n",
"Current Time: 2023-08-24 22:35:49\n",
"Current Time: 2023-08-24 22:35:59\n",
"Current Time: 2023-08-24 22:36:09\n",
"Current Time: 2023-08-24 22:36:19\n",
"Current Time: 2023-08-24 22:36:29\n",
"Current Time: 2023-08-24 22:36:39\n",
"Current Time: 2023-08-24 22:36:49\n",
"Current Time: 2023-08-24 22:36:59\n",
"Current Time: 2023-08-24 22:37:09\n",
"Current Time: 2023-08-24 22:37:19\n",
"Current Time: 2023-08-24 22:37:30\n",
"Current Time: 2023-08-24 22:37:40\n",
"Current Time: 2023-08-24 22:37:50\n",
"Current Time: 2023-08-24 22:38:00\n",
"Current Time: 2023-08-24 22:38:10\n",
"Current Time: 2023-08-24 22:38:20\n",
"Current Time: 2023-08-24 22:38:30\n",
"Current Time: 2023-08-24 22:38:40\n",
"Current Time: 2023-08-24 22:38:50\n",
"Current Time: 2023-08-24 22:39:00\n",
"Current Time: 2023-08-24 22:39:10\n",
"Current Time: 2023-08-24 22:39:20\n",
"Current Time: 2023-08-24 22:39:30\n",
"Current Time: 2023-08-24 22:39:40\n",
"Current Time: 2023-08-24 22:39:50\n",
"Current Time: 2023-08-24 22:40:00\n",
"Current Time: 2023-08-24 22:40:10\n",
"Current Time: 2023-08-24 22:40:20\n",
"Current Time: 2023-08-24 22:40:30\n",
"Current Time: 2023-08-24 22:40:40\n",
"Current Time: 2023-08-24 22:40:50\n",
"Current Time: 2023-08-24 22:41:00\n",
"Current Time: 2023-08-24 22:41:10\n",
"Current Time: 2023-08-24 22:41:20\n",
"Current Time: 2023-08-24 22:41:30\n",
"Current Time: 2023-08-24 22:41:40\n",
"Current Time: 2023-08-24 22:41:50\n",
"Current Time: 2023-08-24 22:42:00\n",
"Current Time: 2023-08-24 22:42:10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 22:42:20\n",
"Current Time: 2023-08-24 22:42:30\n",
"Current Time: 2023-08-24 22:42:40\n",
"Current Time: 2023-08-24 22:42:50\n",
"Current Time: 2023-08-24 22:43:00\n",
"Current Time: 2023-08-24 22:43:10\n",
"Current Time: 2023-08-24 22:43:20\n",
"Current Time: 2023-08-24 22:43:30\n",
"Current Time: 2023-08-24 22:43:40\n",
"Current Time: 2023-08-24 22:43:50\n",
"Current Time: 2023-08-24 22:44:00\n",
"Current Time: 2023-08-24 22:44:10\n",
"Current Time: 2023-08-24 22:44:20\n",
"Current Time: 2023-08-24 22:44:30\n",
"Current Time: 2023-08-24 22:44:40\n",
"Current Time: 2023-08-24 22:44:50\n",
"Current Time: 2023-08-24 22:45:00\n",
"Current Time: 2023-08-24 22:45:10\n",
"Current Time: 2023-08-24 22:45:20\n",
"Current Time: 2023-08-24 22:45:30\n",
"Current Time: 2023-08-24 22:45:40\n",
"Current Time: 2023-08-24 22:45:50\n",
"Current Time: 2023-08-24 22:46:00\n",
"Current Time: 2023-08-24 22:46:10\n",
"Current Time: 2023-08-24 22:46:20\n",
"Current Time: 2023-08-24 22:46:30\n",
"Current Time: 2023-08-24 22:46:40\n",
"Current Time: 2023-08-24 22:46:50\n",
"Current Time: 2023-08-24 22:47:00\n",
"Current Time: 2023-08-24 22:47:10\n",
"Current Time: 2023-08-24 22:47:20\n",
"Current Time: 2023-08-24 22:47:30\n",
"Current Time: 2023-08-24 22:47:40\n",
"Current Time: 2023-08-24 22:47:50\n",
"Current Time: 2023-08-24 22:48:00\n",
"Current Time: 2023-08-24 22:48:10\n",
"Current Time: 2023-08-24 22:48:20\n",
"Current Time: 2023-08-24 22:48:30\n",
"Current Time: 2023-08-24 22:48:40\n",
"Current Time: 2023-08-24 22:48:50\n",
"Current Time: 2023-08-24 22:49:00\n",
"Current Time: 2023-08-24 22:49:10\n",
"Current Time: 2023-08-24 22:49:20\n",
"Current Time: 2023-08-24 22:49:30\n",
"Current Time: 2023-08-24 22:49:40\n",
"Current Time: 2023-08-24 22:49:50\n",
"Current Time: 2023-08-24 22:50:00\n",
"Current Time: 2023-08-24 22:50:10\n",
"Current Time: 2023-08-24 22:50:20\n",
"Current Time: 2023-08-24 22:50:30\n",
"Current Time: 2023-08-24 22:50:40\n",
"Current Time: 2023-08-24 22:50:50\n",
"Current Time: 2023-08-24 22:51:00\n",
"Current Time: 2023-08-24 22:51:10\n",
"Current Time: 2023-08-24 22:51:20\n",
"Current Time: 2023-08-24 22:51:30\n",
"Current Time: 2023-08-24 22:51:40\n",
"Current Time: 2023-08-24 22:51:50\n",
"Current Time: 2023-08-24 22:52:00\n",
"Current Time: 2023-08-24 22:52:10\n",
"Current Time: 2023-08-24 22:52:20\n",
"Current Time: 2023-08-24 22:52:30\n",
"Current Time: 2023-08-24 22:52:40\n",
"Current Time: 2023-08-24 22:52:50\n",
"Current Time: 2023-08-24 22:53:00\n",
"Current Time: 2023-08-24 22:53:10\n",
"Current Time: 2023-08-24 22:53:20\n",
"Current Time: 2023-08-24 22:53:30\n",
"Current Time: 2023-08-24 22:53:40\n",
"Current Time: 2023-08-24 22:53:50\n",
"Current Time: 2023-08-24 22:54:00\n",
"Current Time: 2023-08-24 22:54:11\n",
"Current Time: 2023-08-24 22:54:21\n",
"Current Time: 2023-08-24 22:54:31\n",
"Current Time: 2023-08-24 22:54:41\n",
"Current Time: 2023-08-24 22:54:51\n",
"Current Time: 2023-08-24 22:55:01\n",
"Current Time: 2023-08-24 22:55:11\n",
"Current Time: 2023-08-24 22:55:21\n",
"Current Time: 2023-08-24 22:55:31\n",
"Current Time: 2023-08-24 22:55:41\n",
"Current Time: 2023-08-24 22:55:51\n",
"Current Time: 2023-08-24 22:56:01\n",
"Current Time: 2023-08-24 22:56:11\n",
"Current Time: 2023-08-24 22:56:21\n",
"Current Time: 2023-08-24 22:56:31\n",
"Current Time: 2023-08-24 22:56:41\n",
"Current Time: 2023-08-24 22:56:51\n",
"Current Time: 2023-08-24 22:57:01\n",
"Current Time: 2023-08-24 22:57:11\n",
"Current Time: 2023-08-24 22:57:21\n",
"Current Time: 2023-08-24 22:57:31\n",
"Current Time: 2023-08-24 22:57:41\n",
"Current Time: 2023-08-24 22:57:51\n",
"Current Time: 2023-08-24 22:58:01\n",
"Current Time: 2023-08-24 22:58:11\n",
"Current Time: 2023-08-24 22:58:21\n",
"Current Time: 2023-08-24 22:58:31\n",
"Current Time: 2023-08-24 22:58:41\n",
"Current Time: 2023-08-24 22:58:51\n",
"Current Time: 2023-08-24 22:59:01\n",
"Current Time: 2023-08-24 22:59:11\n",
"Current Time: 2023-08-24 22:59:21\n",
"Current Time: 2023-08-24 22:59:31\n",
"Current Time: 2023-08-24 22:59:41\n",
"Current Time: 2023-08-24 22:59:51\n",
"Current Time: 2023-08-24 23:00:01\n",
"Current Time: 2023-08-24 23:00:11\n",
"Current Time: 2023-08-24 23:00:21\n",
"Current Time: 2023-08-24 23:00:31\n",
"Current Time: 2023-08-24 23:00:41\n",
"Current Time: 2023-08-24 23:00:51\n",
"Current Time: 2023-08-24 23:01:01\n",
"Current Time: 2023-08-24 23:01:11\n",
"Current Time: 2023-08-24 23:01:21\n",
"Current Time: 2023-08-24 23:01:31\n",
"Current Time: 2023-08-24 23:01:41\n",
"Current Time: 2023-08-24 23:01:51\n",
"Current Time: 2023-08-24 23:02:01\n",
"Current Time: 2023-08-24 23:02:11\n",
"Current Time: 2023-08-24 23:02:21\n",
"Current Time: 2023-08-24 23:02:31\n",
"Current Time: 2023-08-24 23:02:41\n",
"Current Time: 2023-08-24 23:02:51\n",
"Current Time: 2023-08-24 23:03:01\n",
"Current Time: 2023-08-24 23:03:11\n",
"Current Time: 2023-08-24 23:03:21\n",
"Current Time: 2023-08-24 23:03:31\n",
"Current Time: 2023-08-24 23:03:41\n",
"Current Time: 2023-08-24 23:03:51\n",
"Current Time: 2023-08-24 23:04:01\n",
"Current Time: 2023-08-24 23:04:11\n",
"Current Time: 2023-08-24 23:04:21\n",
"Current Time: 2023-08-24 23:04:31\n",
"Current Time: 2023-08-24 23:04:41\n",
"Current Time: 2023-08-24 23:04:51\n",
"Current Time: 2023-08-24 23:05:01\n",
"Current Time: 2023-08-24 23:05:11\n",
"Current Time: 2023-08-24 23:05:21\n",
"Current Time: 2023-08-24 23:05:31\n",
"Current Time: 2023-08-24 23:05:41\n",
"Current Time: 2023-08-24 23:05:51\n",
"Current Time: 2023-08-24 23:06:01\n",
"Current Time: 2023-08-24 23:06:11\n",
"Current Time: 2023-08-24 23:06:21\n",
"Current Time: 2023-08-24 23:06:31\n",
"Current Time: 2023-08-24 23:06:41\n",
"Current Time: 2023-08-24 23:06:51\n",
"Current Time: 2023-08-24 23:07:01\n",
"Current Time: 2023-08-24 23:07:12\n",
"Current Time: 2023-08-24 23:07:22\n",
"Current Time: 2023-08-24 23:07:32\n",
"Current Time: 2023-08-24 23:07:42\n",
"Current Time: 2023-08-24 23:07:52\n",
"Current Time: 2023-08-24 23:08:02\n",
"Current Time: 2023-08-24 23:08:12\n",
"Current Time: 2023-08-24 23:08:22\n",
"Current Time: 2023-08-24 23:08:32\n",
"Current Time: 2023-08-24 23:08:42\n",
"Current Time: 2023-08-24 23:08:52\n",
"Current Time: 2023-08-24 23:09:02\n",
"Current Time: 2023-08-24 23:09:12\n",
"Current Time: 2023-08-24 23:09:22\n",
"Current Time: 2023-08-24 23:09:32\n",
"Current Time: 2023-08-24 23:09:42\n",
"Current Time: 2023-08-24 23:09:52\n",
"Current Time: 2023-08-24 23:10:02\n",
"Current Time: 2023-08-24 23:10:12\n",
"Current Time: 2023-08-24 23:10:22\n",
"Current Time: 2023-08-24 23:10:32\n",
"Current Time: 2023-08-24 23:10:42\n",
"Current Time: 2023-08-24 23:10:52\n",
"Current Time: 2023-08-24 23:11:02\n",
"Current Time: 2023-08-24 23:11:12\n",
"Current Time: 2023-08-24 23:11:22\n",
"Current Time: 2023-08-24 23:11:32\n",
"Current Time: 2023-08-24 23:11:42\n",
"Current Time: 2023-08-24 23:11:52\n",
"Current Time: 2023-08-24 23:12:02\n",
"Current Time: 2023-08-24 23:12:12\n",
"Current Time: 2023-08-24 23:12:22\n",
"Current Time: 2023-08-24 23:12:32\n",
"Current Time: 2023-08-24 23:12:42\n",
"Current Time: 2023-08-24 23:12:52\n",
"Current Time: 2023-08-24 23:13:02\n",
"Current Time: 2023-08-24 23:13:12\n",
"Current Time: 2023-08-24 23:13:22\n",
"Current Time: 2023-08-24 23:13:32\n",
"Current Time: 2023-08-24 23:13:42\n",
"Current Time: 2023-08-24 23:13:52\n",
"Current Time: 2023-08-24 23:14:02\n",
"Current Time: 2023-08-24 23:14:12\n",
"Current Time: 2023-08-24 23:14:22\n",
"Current Time: 2023-08-24 23:14:32\n",
"Current Time: 2023-08-24 23:14:42\n"
]
}
],
"source": [
"import time\n",
"\n",
"while True:\n",
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
" print(f\"Current Time: {current_time}\")\n",
" time.sleep(10)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 30,000 to 40,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 40,000 to 50,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Cost to Classify One Recipe</th>\n",
" <th>Cost to Classify Entire Dataset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
"finetuned_hourly_cost = 1.14\n",
"\n",
"finetuned_total_hours = 16.54\n",
"\n",
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
"\n",
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
"avg_input_tokens = 276\n",
"avg_output_tokens = 42\n",
"\n",
"# Token pricing from https://openai.com/pricing\n",
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
"\n",
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
"\n",
"gpt_35_finetuned_avg_cost = (\n",
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"# Multiply the number of recipes\n",
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
"\n",
"# Let's put this in a dataframe for easier comparison.\n",
"\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
" finetuned_avg_cost,\n",
" gpt_35_avg_cost,\n",
" gpt_35_finetuned_avg_cost,\n",
" gpt_4_avg_cost,\n",
" ],\n",
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"\n",
"costs\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,248 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 30%|███ | 3008/10000 [01:29<03:27, 33.73it/s]"
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -88,12 +88,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? Since that is data formatted the same way as our training data but that we didn't use for training, we can use it to check our model's performance."
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -111,7 +111,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -121,7 +121,7 @@
"Sample prompt:\n",
"--------------\n",
"### Instruction:\n",
"[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions. You should \"},{\"role\":\"user\",\"content\":\"Strawberry Sorbet\\n\\nIngredients:\\n- 2 cups chopped strawberries Safeway 1 lb For $3.99 thru 02/09\\n- 1 cup cold water\\n- 2 cups boiling water\\n- 1 pkg. (4-serving size) JELL-O Strawberry Flavor Gelatin\\n- 1/2 cup sugar\\n\\nDirections:\\n- Place strawberries and cold water in blender container; cover.\\n- Blend on high speed until smooth.\\n- Stir boiling water into combined dry gelatin mix and sugar in medium bowl at least 2 minutes until completely dissolved.\\n- Add strawberry mixture; mix well.\\n- Pour into 9-inch square pan.\\n- Freeze 1 to 1-1/2 hours or until ice crystals form 1 inch around edges of pan.\\n- Spoon half of the gelatin mixture into blender container; cover.\\n- Blend on high speed about 30 seconds or until smooth; pour into bowl.\\n- Repeat with remaining gelatin mixture.\\n- Add to blended gelatin mixture in bowl; mix well.\\n- Return to pan.\\n- Freeze 6 hours or overnight until firm.\\n- Scoop into dessert dishes to serve.\\n- Store leftover sorbet in freezer.\"}]\n",
"[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\"},{\"role\":\"user\",\"content\":\"Pan Gravy\\n\\nIngredients:\\n- 1/3 cup all purpose flour\\n- 1/3 cup turkey drippings\\n- 3 cup water or broth\\n- 1/8 to 1/4 teaspoon salt\\n- 1/8 tsp pepper\\n\\nDirections:\\n- In a skillet or roasting pan, add flour to drippings; blend well.\\n- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\\n- Add water; cook until mixture boils and thickens, stirring constantly.\\n- Stir in salt and pepper.\\n- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\\n- *\"}]\n",
"\n",
"### Response:\n",
"\n"
@@ -152,22 +152,22 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 18:58:03 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 18:59:18 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
"INFO 08-24 22:01:58 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 22:02:46 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 201/201 [00:16<00:00, 12.18it/s]"
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.44it/s]"
]
},
{
@@ -176,7 +176,7 @@
"text": [
"Sample output:\n",
"--------------\n",
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": false,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_course\\\": false\\n}\"}}\n"
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": true,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_dish\\\": false\\n}\"}}\n"
]
},
{
@@ -211,19 +211,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our outputs! Since there are 5 categories we classify each recipe on, a natural metric would be for each recipe and each category, what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function to check that."
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overall accuracy: 0.91\n"
"Overall accuracy: 0.95\n"
]
}
],
@@ -231,7 +231,7 @@
"import json\n",
"\n",
"\n",
"def parse_fn_call_args(str):\n",
"def parse_fn_call(str):\n",
" \"\"\"Parse the function call arguments from the response\"\"\"\n",
" response_dict = json.loads(str)\n",
" args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
@@ -241,12 +241,12 @@
"\n",
"def calculate_accuracy(row):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = parse_fn_call_args(row[\"output\"])\n",
" my_outputs = parse_fn_call_args(row[\"my_outputs\"])\n",
" true_outputs = parse_fn_call(row[\"output\"])\n",
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if true_outputs[key] == my_outputs[key]:\n",
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
@@ -261,9 +261,371 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! Of course, the next obvious step is to look at where Llama 2 is \"wrong\" and evaluate the types of errors it makes. I've exported a Google Sheet where I did exactly that with an earlier version of this model trained on the same dataset. You can see that [here](https://docs.google.com/spreadsheets/d/1vn-nA0CRQwz-BvEYvxUcO1-EP80ZbPhcxDoCTttvsmI/edit?usp=sharing).\n",
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Alligator Sauce Piquant\n",
"\n",
"Ingredients:\n",
"- 2 lb. alligator, boneless and cubed *\n",
"- 4 onions, diced\n",
"- 1 c. parsley, chopped\n",
"- 4 stalks celery, chopped\n",
"- 1 bell pepper, diced\n",
"- 1 c. catsup\n",
"- 2 Tbsp. Heinz steak sauce\n",
"- 2 Tbsp. soy sauce\n",
"- 2 Tbsp. Louisiana hot sauce\n",
"- 2 Tbsp. cornstarch\n",
"- 1 tsp. salt\n",
"- 2 tsp. red pepper (ground)\n",
"- 1/4 c. cooking oil\n",
"\n",
"Directions:\n",
"- *Alligator must be free of all fat; also dark meat is the best (leg and body meat), boneless.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>cook_time_over_30_mins</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"cook_time_over_30_mins True False\n",
"main_dish True False"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Veggie Casserole\n",
"\n",
"Ingredients:\n",
"- 1 (8 oz.) bag mixed veggies (corn, peas, carrots, green beans), steamed\n",
"- 1 c. celery\n",
"- 1 c. onions\n",
"- 1 c. Cheddar cheese\n",
"- 1 c. mayonnaise\n",
"\n",
"Directions:\n",
"- Mix above ingredients.\n",
"- Bake at 350° for 30 minutes, until bubbly.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rhonda'S Butter Chess Pie\n",
"\n",
"Ingredients:\n",
"- 5 eggs\n",
"- 1 stick melted butter\n",
"- 2 c. sugar\n",
"- 1 tsp. vanilla\n",
"- 1 Tbsp. cornstarch\n",
"- 1/2 c. buttermilk\n",
"- unbaked 9-inch deep dish pie shell\n",
"\n",
"Directions:\n",
"- Mix eggs with sugar and cornstarch until smooth.\n",
"- Add melted butter, vanilla and buttermilk.\n",
"- Bake at 350° for 30 minutes or until done.\n",
"- Let cool and chill.\n",
"- Similar to Furr's Butter Chess Pie.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>cook_time_over_30_mins</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"cook_time_over_30_mins False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Broccoli Gorgonzola Cream Soup\n",
"\n",
"Ingredients:\n",
"- 2 heads Broccoli\n",
"- 700 milliliters Water\n",
"- 1 Onion, Peeled And Cut Into Chunks\n",
"- 1 pinch Salt\n",
"- 1 teaspoon Oregano\n",
"- 1 Potato, Peeled And Cut Into Chunks\n",
"- 200 grams Crumbled Gorgonzola\n",
"- 1 Tablespoon Finely Grated Parmesan\n",
"\n",
"Directions:\n",
"- Cut off the hard trunks of the broccoli and cut it into small pieces. Prepare a pot with water, add broccoli, onion, salt and oregano and boil for about 30 minutes.\n",
"- Add the peeled potato and boil for another 20 minutes. When vegetables are cooked, strain and save the stock.\n",
"- Using a hand blender, puree vegetables, adding as much stock as desired. Bring soup back to heat over low heat, and sir in gorgonzola. Remove from heat and add Parmesan.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wild Rice With Cucumber And Feta\n",
"\n",
"Ingredients:\n",
"- 1 (8.5-ounce) package precooked wild rice (such as Archer Farms)\n",
"- 1 cup diced English cucumber\n",
"- 1 1/2 tablespoons olive oil\n",
"- 1 tablespoon fresh lemon juice\n",
"- 2 ounces crumbled feta cheese\n",
"- 1/2 teaspoon pepper\n",
"- 1/4 teaspoon salt\n",
"\n",
"Directions:\n",
"- Prepare rice according to the package directions.\n",
"- Combine cooked rice, cucumber, olive oil, lemon juice, and crumbled feta cheese in a medium bowl; toss to coat. Stir in pepper and salt.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish True False"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"\n",
"The main takeaway: generally places where GPT-4 and Llama 2 disagreed were genuinely ambiguous cases, where either answer was acceptable (eg. a dish that takes about 30 mins to cook might be classified as over 30 minutes by one, and under 30 minutes by the other).\n",
"np.random.seed(42)\n",
"\n",
"for row in test_data[test_data.accuracy < 1].sample(5).itertuples():\n",
" print(json.loads(row.instruction)[1][\"content\"])\n",
"\n",
" gpt4_output = parse_fn_call(row.output)\n",
" my_output = parse_fn_call(row.my_outputs)\n",
"\n",
" table = pd.DataFrame(\n",
" {\n",
" \"GPT-4\": gpt4_output,\n",
" \"My model\": my_output,\n",
" }\n",
" )\n",
"\n",
" table = table[table[\"GPT-4\"] != table[\"My model\"]]\n",
" display(table)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
]

View File

@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -53,7 +53,7 @@
}
],
"source": [
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2"
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2 datasets==2.14.4"
]
},
{
@@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -138,14 +138,14 @@
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: we're looking for meals that my pescatarian brother can eat in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"That looks like a pretty random list, but there's actually an important unifying thread: I'm looking for meals that my pescatarian brother/co-founder can make in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -153,13 +153,7 @@
"output_type": "stream",
"text": [
"Classifying first recipe:\n",
"------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------\n",
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': True, 'main_dish': True}\n"
]
}
@@ -225,7 +219,7 @@
" \"requires_oven\",\n",
" \"requires_stove\",\n",
" \"cook_time_over_30_mins\",\n",
" \"main_course\",\n",
" \"main_dish\",\n",
" ],\n",
" },\n",
" }\n",
@@ -246,12 +240,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- let's see if we can make it as good!"
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -302,7 +296,12 @@
"Classifying recipe 4100/5000: Prune Cake\n",
"Classifying recipe 4200/5000: Strawberry Sorbet\n",
"Classifying recipe 4300/5000: Lemonade Chicken\n",
"Classifying recipe 4400/5000: Crock-Pot Vegetarian Chili\n"
"Classifying recipe 4400/5000: Crock-Pot Vegetarian Chili\n",
"Classifying recipe 4500/5000: Grandma Dickrell'S Molasses Cake - 1936\n",
"Classifying recipe 4600/5000: Creamed Corn Casserole\n",
"Classifying recipe 4700/5000: Homemade Croutons\n",
"Classifying recipe 4800/5000: Potatoes With Leeks And Gruyere\n",
"Classifying recipe 4900/5000: Chocolate Oatmeal Cookie\n"
]
}
],

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ from peft import PeftModel
import os
def merge(config_file: str):
def merge_lora_model(config_file: str):
config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)
base_model = config["base_model"]