more benchmarking

This commit is contained in:
Kyle Corbitt
2023-08-24 19:52:31 +00:00
parent 13bac46e0b
commit 14eae45d18
4 changed files with 455 additions and 52 deletions

View File

@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 30%|███ | 3008/10000 [01:29<03:27, 33.73it/s]"
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -9,7 +9,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -111,18 +111,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"ename": "ModuleNotFoundError", "name": "stdout",
"evalue": "No module named 'axolotl.prompters'", "output_type": "stream",
"output_type": "error", "text": [
"traceback": [ "Sample prompt:\n",
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "--------------\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "### Instruction:\n",
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39maxolotl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompters\u001b[39;00m \u001b[39mimport\u001b[39;00m UnpromptedPrompter\n\u001b[1;32m 2\u001b[0m prompter \u001b[39m=\u001b[39m UnpromptedPrompter()\n\u001b[1;32m 4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_prompt\u001b[39m(\u001b[39minput\u001b[39m: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n", "[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions. You should \"},{\"role\":\"user\",\"content\":\"Strawberry Sorbet\\n\\nIngredients:\\n- 2 cups chopped strawberries Safeway 1 lb For $3.99 thru 02/09\\n- 1 cup cold water\\n- 2 cups boiling water\\n- 1 pkg. (4-serving size) JELL-O Strawberry Flavor Gelatin\\n- 1/2 cup sugar\\n\\nDirections:\\n- Place strawberries and cold water in blender container; cover.\\n- Blend on high speed until smooth.\\n- Stir boiling water into combined dry gelatin mix and sugar in medium bowl at least 2 minutes until completely dissolved.\\n- Add strawberry mixture; mix well.\\n- Pour into 9-inch square pan.\\n- Freeze 1 to 1-1/2 hours or until ice crystals form 1 inch around edges of pan.\\n- Spoon half of the gelatin mixture into blender container; cover.\\n- Blend on high speed about 30 seconds or until smooth; pour into bowl.\\n- Repeat with remaining gelatin mixture.\\n- Add to blended gelatin mixture in bowl; mix well.\\n- Return to pan.\\n- Freeze 6 hours or overnight until firm.\\n- Scoop into dessert dishes to serve.\\n- Store leftover sorbet in freezer.\"}]\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'axolotl.prompters'" "\n",
"### Response:\n",
"\n"
] ]
} }
], ],
@@ -136,9 +138,134 @@
" return next(prompter.build_prompt(input))\n", " return next(prompter.build_prompt(input))\n",
"\n", "\n",
"\n", "\n",
"prompts = test_data[\"input\"].apply(format_prompt)\n", "prompts = test_data[\"instruction\"].apply(format_prompt)\n",
"\n", "\n",
"print(f\"Sample prompt:\\n-----------\\n{prompts[0]}\")\n" "print(f\"Sample prompt:\\n--------------\\n{prompts[0]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 18:58:03 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 18:59:18 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 201/201 [00:16<00:00, 12.18it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample output:\n",
"--------------\n",
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": false,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_course\\\": false\\n}\"}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n",
"\n",
"my_outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
"my_outputs = [o.outputs[0].text for o in my_outputs]\n",
"\n",
"test_data[\"my_outputs\"] = my_outputs\n",
"\n",
"print(f\"Sample output:\\n--------------\\n{my_outputs[0]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our outputs! Since there are 5 categories we classify each recipe on, a natural metric would be for each recipe and each category, what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function to check that."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overall accuracy: 0.91\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def parse_fn_call_args(str):\n",
" \"\"\"Parse the function call arguments from the response\"\"\"\n",
" response_dict = json.loads(str)\n",
" args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
"\n",
" return args_dict\n",
"\n",
"\n",
"def calculate_accuracy(row):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = parse_fn_call_args(row[\"output\"])\n",
" my_outputs = parse_fn_call_args(row[\"my_outputs\"])\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if true_outputs[key] == my_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
"\n",
"\n",
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
"\n",
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! Of course, the next obvious step is to look at where Llama 2 is \"wrong\" and evaluate the types of errors it makes. I've exported a Google Sheet where I did exactly that with an earlier version of this model trained on the same dataset. You can see that [here](https://docs.google.com/spreadsheets/d/1vn-nA0CRQwz-BvEYvxUcO1-EP80ZbPhcxDoCTttvsmI/edit?usp=sharing).\n",
"\n",
"The main takeaway: generally places where GPT-4 and Llama 2 disagreed were genuinely ambiguous cases, where either answer was acceptable (eg. a dish that takes about 30 mins to cook might be classified as over 30 minutes by one, and under 30 minutes by the other).\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
] ]
} }
], ],

View File

@@ -11,7 +11,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -19,6 +19,8 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n", "Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n", "Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n", "Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n", "Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
@@ -63,14 +65,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Recipe dataset:\n" "Recipe dataset shape:\n",
"------------------\n"
] ]
}, },
{ {
@@ -90,7 +93,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"First recipe:\n", "First recipe:\n",
" Shrimp Creole\n", "------------------ Shrimp Creole\n",
"\n", "\n",
"Ingredients:\n", "Ingredients:\n",
"- 20 shrimp (8 oz.)\n", "- 20 shrimp (8 oz.)\n",
@@ -128,27 +131,37 @@
"source": [ "source": [
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n", "Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
"\n", "\n",
"We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`." "In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
" - has_non_fish_meat\n",
" - requires_oven\n",
" - requires_stove\n",
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: we're looking for meals that my pescatarian brother can eat in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"{'has_non_fish_meat': False,\n", "text": [
" 'requires_oven': True,\n", "Classifying first recipe:\n",
" 'requires_stove': True,\n", "------------------\n"
" 'cook_time_over_30_mins': False,\n", ]
" 'main_dish': False}" },
] {
}, "name": "stdout",
"execution_count": 9, "output_type": "stream",
"metadata": {}, "text": [
"output_type": "execute_result" "{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': True, 'main_dish': True}\n"
]
} }
], ],
"source": [ "source": [
@@ -157,10 +170,13 @@
"import os\n", "import os\n",
"import dotenv\n", "import dotenv\n",
"\n", "\n",
"# Use `dotenv` to load the contents of the `.env` file into the environment\n",
"dotenv.load_dotenv()\n", "dotenv.load_dotenv()\n",
"\n", "\n",
"# Configure OpenPipe using the API key from the environment\n",
"configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n", "configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n",
"\n", "\n",
"# Configure OpenAI using the API key from the environment\n",
"openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n", "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
"\n", "\n",
"\n", "\n",
@@ -222,12 +238,20 @@
" return json.loads(completion.choices[0].message.function_call.arguments)\n", " return json.loads(completion.choices[0].message.function_call.arguments)\n",
"\n", "\n",
"\n", "\n",
"classify_recipe(recipes[\"recipe\"][-1])\n" "print(\"Classifying first recipe:\\n------------------\")\n",
"print(classify_recipe(recipes[\"recipe\"][0]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- let's see if we can make it as good!"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -238,8 +262,6 @@
"Classifying recipe 100/5000: Spoon Bread\n", "Classifying recipe 100/5000: Spoon Bread\n",
"Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n", "Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n",
"Classifying recipe 300/5000: Broccoli Casserole\n", "Classifying recipe 300/5000: Broccoli Casserole\n",
"Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
"520 is not a valid HTTPStatus\n",
"Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n", "Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n",
"Classifying recipe 500/5000: Dirt Dessert\n", "Classifying recipe 500/5000: Dirt Dessert\n",
"Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n", "Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n",
@@ -265,21 +287,22 @@
"Classifying recipe 2600/5000: Pepperoni Bread\n", "Classifying recipe 2600/5000: Pepperoni Bread\n",
"Classifying recipe 2700/5000: Sabzi Polow\n", "Classifying recipe 2700/5000: Sabzi Polow\n",
"Classifying recipe 2800/5000: Italian Vegetable Pizzas\n", "Classifying recipe 2800/5000: Italian Vegetable Pizzas\n",
"Error classifying recipe 2801: Bad gateway. {\"error\":{\"code\":502,\"message\":\"Bad gateway.\",\"param\":null,\"type\":\"cf_bad_gateway\"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Thu, 24 Aug 2023 15:44:45 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7fbca943df684de1-MCI', 'alt-svc': 'h3=\":443\"; ma=86400'}\n",
"Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n", "Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n",
"Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n", "Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n",
"Classifying recipe 3100/5000: Herbed Potatoes And Onions\n", "Classifying recipe 3100/5000: Herbed Potatoes And Onions\n",
"Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n", "Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n",
"Classifying recipe 3300/5000: Pineapple-Orange Punch\n", "Classifying recipe 3300/5000: Pineapple-Orange Punch\n",
"Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n", "Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n",
"Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
"520 is not a valid HTTPStatus\n",
"Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n", "Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n",
"Classifying recipe 3600/5000: Triple Chocolate Cookies\n", "Classifying recipe 3600/5000: Triple Chocolate Cookies\n",
"Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n", "Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n",
"Error classifying recipe 3779: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600)\n",
"Classifying recipe 3800/5000: Chicken Croquettes\n", "Classifying recipe 3800/5000: Chicken Croquettes\n",
"Classifying recipe 3900/5000: Mushroom Casserole\n" "Classifying recipe 3900/5000: Mushroom Casserole\n",
"Classifying recipe 4000/5000: Vegetarian Summer Roll\n",
"Classifying recipe 4100/5000: Prune Cake\n",
"Classifying recipe 4200/5000: Strawberry Sorbet\n",
"Classifying recipe 4300/5000: Lemonade Chicken\n",
"Classifying recipe 4400/5000: Crock-Pot Vegetarian Chili\n"
] ]
} }
], ],
@@ -295,12 +318,14 @@
] ]
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"Ok, we have our " "Ok, now that my recipes are classified I'll download the training data. \n",
"\n",
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details! Just go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
"\n",
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`."
] ]
} }
], ],

View File

@@ -9,7 +9,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -51,12 +51,12 @@
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n", "Note: you may need to restart the kernel to use updated packages.\n",
"fatal: destination path 'axolotl' already exists and is not an empty directory.\n", "fatal: destination path 'axolotl' already exists and is not an empty directory.\n",
"Obtaining file:///workspace/gpt4-fine-tuning/axolotl\n", "Obtaining file:///workspace/OpenPipe/examples/classify-recipes/axolotl\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\n", " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n", "\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n",
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-8yfermge/transformers_22e4388baf16446d8445557008e38efe\n", " Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-o3o9dk76/transformers_99ed72a1465e41bba173c85b3be82a1b\n",
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-8yfermge/transformers_22e4388baf16446d8445557008e38efe\n", " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-o3o9dk76/transformers_99ed72a1465e41bba173c85b3be82a1b\n",
" Resolved https://github.com/huggingface/transformers.git to commit 4d40109c3a93c9b8bbca204cb046ed510f1c72e8\n", " Resolved https://github.com/huggingface/transformers.git to commit f26099e7b5cf579f99a42bab6ddd371bf2c8d548\n",
" Installing build dependencies ... \u001b[?25ldone\n", " Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
@@ -152,10 +152,6 @@
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n",
"Installing collected packages: axolotl\n", "Installing collected packages: axolotl\n",
" Attempting uninstall: axolotl\n",
" Found existing installation: axolotl 0.1\n",
" Uninstalling axolotl-0.1:\n",
" Successfully uninstalled axolotl-0.1\n",
" Running setup.py develop for axolotl\n", " Running setup.py develop for axolotl\n",
"Successfully installed axolotl\n", "Successfully installed axolotl\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
@@ -421,6 +417,13 @@
"final_model_dir = merge_lora_model(\"training-config.yaml\")\n", "final_model_dir = merge_lora_model(\"training-config.yaml\")\n",
"print(f\"Final model saved to '{final_model_dir}'\")\n" "print(f\"Final model saved to '{final_model_dir}'\")\n"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check."
]
} }
], ],
"metadata": { "metadata": {