OpenPipe-llm/examples/classify-recipes/evaluate.ipynb
Kyle Corbitt 40638a7848 more work
2023-08-24 23:49:44 +00:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes, and I need to figure out how well it does that job. First, I'll install the dependencies."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install vllm==0.1.3 pandas==2.0.3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember the \"test.jsonl\" file I got from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? It holds examples from our dataset that we didn't use in training, so we can use them to check how well our model handles data it hasn't seen."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"test_data = pd.read_json(\"./data/test.jsonl\", lines=True)\n"
]
},
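{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of `test.jsonl` pairs an `instruction` (a JSON-encoded list of chat messages) with an `output` (GPT-4's reply, also JSON-encoded), and the cells below rely on that shape. As a minimal sketch, the next cell builds a made-up one-row sample with the same two fields and reads it the same way; the `example_*` names and the sample text are hypothetical, and real rows carry the full system prompt and recipe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import json\n",
"\n",
"import pandas as pd\n",
"\n",
"# Hypothetical one-row sample with the same two fields as test.jsonl.\n",
"example_messages = [\n",
"    {\"role\": \"system\", \"content\": \"Your goal is to classify a recipe.\"},\n",
"    {\"role\": \"user\", \"content\": \"Buttered Toast. Ingredients: bread, butter.\"},\n",
"]\n",
"example_reply = {\n",
"    \"role\": \"assistant\",\n",
"    \"content\": None,\n",
"    \"function_call\": {\n",
"        \"name\": \"classify\",\n",
"        \"arguments\": json.dumps({\"main_dish\": False}),\n",
"    },\n",
"}\n",
"# Both fields are themselves JSON strings, so each one is encoded before the line is written.\n",
"example_line = json.dumps(\n",
"    {\"instruction\": json.dumps(example_messages), \"output\": json.dumps(example_reply)}\n",
")\n",
"\n",
"example_df = pd.read_json(io.StringIO(example_line), lines=True)\n",
"\n",
"# Message index 1 is the user turn, which holds the recipe text.\n",
"example_recipe = json.loads(example_df[\"instruction\"][0])[1][\"content\"]\n",
"# The reply's function-call arguments are a second, nested JSON string.\n",
"example_args = json.loads(json.loads(example_df[\"output\"][0])[\"function_call\"][\"arguments\"])\n",
"print(example_recipe)\n",
"print(example_args)"
]
},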
{
"cell_type": "markdown",
"metadata": {},
"source": [
"During training, Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\", named after [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). For best results, I need to transform my test data into the same format."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample prompt:\n",
"--------------\n",
"### Instruction:\n",
"[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\"},{\"role\":\"user\",\"content\":\"Pan Gravy\\n\\nIngredients:\\n- 1/3 cup all purpose flour\\n- 1/3 cup turkey drippings\\n- 3 cup water or broth\\n- 1/8 to 1/4 teaspoon salt\\n- 1/8 tsp pepper\\n\\nDirections:\\n- In a skillet or roasting pan, add flour to drippings; blend well.\\n- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\\n- Add water; cook until mixture boils and thickens, stirring constantly.\\n- Stir in salt and pepper.\\n- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\\n- *\"}]\n",
"\n",
"### Response:\n",
"\n"
]
}
],
"source": [
"from axolotl.prompters import UnpromptedPrompter\n",
"\n",
"prompter = UnpromptedPrompter()\n",
"\n",
"\n",
"def format_prompt(instruction: str) -> str:\n",
"    # build_prompt returns a generator; next() takes the prompt it yields.\n",
"    return next(prompter.build_prompt(instruction))\n",
"\n",
"\n",
"prompts = test_data[\"instruction\"].apply(format_prompt)\n",
"\n",
"print(f\"Sample prompt:\\n--------------\\n{prompts[0]}\")\n"
]
},
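{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sample prompt above has a simple layout: an `### Instruction:` header, the instruction text, a blank line, and a `### Response:` header for the model to complete. The next cell is a sketch that rebuilds this layout by hand. The template is inferred from the printed sample, not copied from Axolotl's source, and `build_alpaca_style_prompt` is a hypothetical helper; `format_prompt` remains the function this notebook actually uses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical template, inferred from the sample prompt printed above.\n",
"ALPACA_STYLE_TEMPLATE = \"### Instruction:\\n{instruction}\\n\\n### Response:\\n\"\n",
"\n",
"\n",
"def build_alpaca_style_prompt(instruction: str) -> str:\n",
"    \"\"\"Wrap an instruction in the instruction/response layout, with no system preamble.\"\"\"\n",
"    return ALPACA_STYLE_TEMPLATE.format(instruction=instruction)\n",
"\n",
"\n",
"example_prompt = build_alpaca_style_prompt('[{\"role\":\"user\",\"content\":\"Buttered Toast\"}]')\n",
"print(example_prompt)\n",
"\n",
"# If the inferred template is right, build_alpaca_style_prompt(test_data[\"instruction\"][0])\n",
"# should equal prompts[0]."
]
},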
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 22:01:58 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 22:02:46 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.44it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample output:\n",
"--------------\n",
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": true,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_dish\\\": false\\n}\"}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
"    # 120 tokens should comfortably fit the short function-call JSON we expect back.\n",
"    max_tokens=120,\n",
"    # temperature=0 means greedy decoding, which suits a deterministic classification task.\n",
"    temperature=0,\n",
")\n",
"\n",
"my_outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
"my_outputs = [o.outputs[0].text for o in my_outputs]\n",
"\n",
"test_data[\"my_outputs\"] = my_outputs\n",
"\n",
"print(f\"Sample output:\\n--------------\\n{my_outputs[0]}\")\n"
]
},
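{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since generation stops after `max_tokens=120`, it's worth confirming that every output is complete, parseable JSON before scoring, because a single malformed output would make `json.loads` raise during evaluation. The next cell is a minimal sketch of such a check: `is_parseable_fn_call` is a hypothetical helper, and the two generations are made up (the second is deliberately cut off). In this notebook, something like `test_data[\"my_outputs\"].apply(is_parseable_fn_call).all()` would tell you whether any output needs attention."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"\n",
"def is_parseable_fn_call(text: str) -> bool:\n",
"    \"\"\"Return True if `text` is a JSON message whose function-call arguments are also valid JSON.\"\"\"\n",
"    try:\n",
"        message = json.loads(text)\n",
"        json.loads(message[\"function_call\"][\"arguments\"])\n",
"    except (json.JSONDecodeError, KeyError, TypeError):\n",
"        # Invalid JSON, a missing key, or a value of the wrong type.\n",
"        return False\n",
"    return True\n",
"\n",
"\n",
"# Made-up generations: one complete, one cut off partway through.\n",
"complete_output = json.dumps(\n",
"    {\n",
"        \"role\": \"assistant\",\n",
"        \"content\": None,\n",
"        \"function_call\": {\"name\": \"classify\", \"arguments\": json.dumps({\"main_dish\": True})},\n",
"    }\n",
")\n",
"truncated_output = complete_output[:60]\n",
"\n",
"print([is_parseable_fn_call(t) for t in [complete_output, truncated_output]])"
]
},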
{
"cell_type": "markdown",
"metadata": {},
"source": [
"OK, we have our outputs! Each recipe is classified along 5 dimensions, so let's check how often our model's answers match GPT-4's. I'll write a quick eval function that scores each recipe by the fraction of its 5 dimensions where the two agree, then averages those scores:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overall accuracy: 0.95\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def parse_fn_call(response: str) -> dict:\n",
"    \"\"\"Parse the function call arguments from a JSON-encoded assistant message\"\"\"\n",
"    response_dict = json.loads(response)\n",
"    args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
"\n",
"    return args_dict\n",
"\n",
"\n",
"def calculate_accuracy(row):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = parse_fn_call(row[\"output\"])\n",
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
"\n",
"\n",
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
"\n",
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
]
},
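{
"cell_type": "markdown",
"metadata": {},
"source": [
"The overall number averages over all five dimensions, so it can hide a single dimension the model finds hard. A per-dimension breakdown is one way to check. The next cell sketches it on made-up parsed outputs for two recipes and two dimensions; `reference` and `predicted` are hypothetical stand-ins for `parse_fn_call` applied to the `output` and `my_outputs` columns. When every reference row has the same keys, the mean of the per-dimension accuracies equals the overall accuracy computed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical parsed outputs for two recipes; in this notebook they would come from\n",
"# applying parse_fn_call to the \"output\" and \"my_outputs\" columns.\n",
"reference = [\n",
"    {\"main_dish\": True, \"requires_oven\": False},\n",
"    {\"main_dish\": False, \"requires_oven\": True},\n",
"]\n",
"predicted = [\n",
"    {\"main_dish\": False, \"requires_oven\": False},\n",
"    {\"main_dish\": False, \"requires_oven\": True},\n",
"]\n",
"\n",
"# One row per recipe, one boolean column per dimension (True = model matches GPT-4).\n",
"# A key missing from the prediction counts as a mismatch for these boolean fields.\n",
"matches = pd.DataFrame(\n",
"    [{key: pred.get(key) == ref[key] for key in ref} for ref, pred in zip(reference, predicted)]\n",
")\n",
"\n",
"per_dimension_accuracy = matches.mean()\n",
"print(per_dimension_accuracy)\n",
"print(f\"Overall: {per_dimension_accuracy.mean():.2f}\")"
]
},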
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! However, there are still some rows where the model's outputs don't match GPT-4's. Let's take a closer look at a random sample of them."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Alligator Sauce Piquant\n",
"\n",
"Ingredients:\n",
"- 2 lb. alligator, boneless and cubed *\n",
"- 4 onions, diced\n",
"- 1 c. parsley, chopped\n",
"- 4 stalks celery, chopped\n",
"- 1 bell pepper, diced\n",
"- 1 c. catsup\n",
"- 2 Tbsp. Heinz steak sauce\n",
"- 2 Tbsp. soy sauce\n",
"- 2 Tbsp. Louisiana hot sauce\n",
"- 2 Tbsp. cornstarch\n",
"- 1 tsp. salt\n",
"- 2 tsp. red pepper (ground)\n",
"- 1/4 c. cooking oil\n",
"\n",
"Directions:\n",
"- *Alligator must be free of all fat; also dark meat is the best (leg and body meat), boneless.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>cook_time_over_30_mins</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"cook_time_over_30_mins True False\n",
"main_dish True False"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Veggie Casserole\n",
"\n",
"Ingredients:\n",
"- 1 (8 oz.) bag mixed veggies (corn, peas, carrots, green beans), steamed\n",
"- 1 c. celery\n",
"- 1 c. onions\n",
"- 1 c. Cheddar cheese\n",
"- 1 c. mayonnaise\n",
"\n",
"Directions:\n",
"- Mix above ingredients.\n",
"- Bake at 350° for 30 minutes, until bubbly.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rhonda'S Butter Chess Pie\n",
"\n",
"Ingredients:\n",
"- 5 eggs\n",
"- 1 stick melted butter\n",
"- 2 c. sugar\n",
"- 1 tsp. vanilla\n",
"- 1 Tbsp. cornstarch\n",
"- 1/2 c. buttermilk\n",
"- unbaked 9-inch deep dish pie shell\n",
"\n",
"Directions:\n",
"- Mix eggs with sugar and cornstarch until smooth.\n",
"- Add melted butter, vanilla and buttermilk.\n",
"- Bake at 350° for 30 minutes or until done.\n",
"- Let cool and chill.\n",
"- Similar to Furr's Butter Chess Pie.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>cook_time_over_30_mins</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"cook_time_over_30_mins False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Broccoli Gorgonzola Cream Soup\n",
"\n",
"Ingredients:\n",
"- 2 heads Broccoli\n",
"- 700 milliliters Water\n",
"- 1 Onion, Peeled And Cut Into Chunks\n",
"- 1 pinch Salt\n",
"- 1 teaspoon Oregano\n",
"- 1 Potato, Peeled And Cut Into Chunks\n",
"- 200 grams Crumbled Gorgonzola\n",
"- 1 Tablespoon Finely Grated Parmesan\n",
"\n",
"Directions:\n",
"- Cut off the hard trunks of the broccoli and cut it into small pieces. Prepare a pot with water, add broccoli, onion, salt and oregano and boil for about 30 minutes.\n",
"- Add the peeled potato and boil for another 20 minutes. When vegetables are cooked, strain and save the stock.\n",
"- Using a hand blender, puree vegetables, adding as much stock as desired. Bring soup back to heat over low heat, and sir in gorgonzola. Remove from heat and add Parmesan.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish False True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wild Rice With Cucumber And Feta\n",
"\n",
"Ingredients:\n",
"- 1 (8.5-ounce) package precooked wild rice (such as Archer Farms)\n",
"- 1 cup diced English cucumber\n",
"- 1 1/2 tablespoons olive oil\n",
"- 1 tablespoon fresh lemon juice\n",
"- 2 ounces crumbled feta cheese\n",
"- 1/2 teaspoon pepper\n",
"- 1/4 teaspoon salt\n",
"\n",
"Directions:\n",
"- Prepare rice according to the package directions.\n",
"- Combine cooked rice, cucumber, olive oil, lemon juice, and crumbled feta cheese in a medium bowl; toss to coat. Stir in pepper and salt.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GPT-4</th>\n",
" <th>My model</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>main_dish</th>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" GPT-4 My model\n",
"main_dish True False"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"\n",
"np.random.seed(42)\n",
"\n",
"for row in test_data[test_data.accuracy < 1].sample(5).itertuples():\n",
" print(json.loads(row.instruction)[1][\"content\"])\n",
"\n",
" gpt4_output = parse_fn_call(row.output)\n",
" my_output = parse_fn_call(row.my_outputs)\n",
"\n",
" table = pd.DataFrame(\n",
" {\n",
" \"GPT-4\": gpt4_output,\n",
" \"My model\": my_output,\n",
" }\n",
" )\n",
"\n",
" table = table[table[\"GPT-4\"] != table[\"My model\"]]\n",
" display(table)\n"
]
},
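{
"cell_type": "markdown",
"metadata": {},
"source": [
"A handful of random samples shows what the disagreements look like, but not which way they lean. For a single dimension such as `main_dish`, a cross-tabulation counts how often the model says `True` when GPT-4 says `False`, and vice versa. The next cell is a minimal sketch on made-up labels for six recipes; `gpt4_main_dish` and `model_main_dish` are hypothetical stand-ins for that field pulled from the parsed `output` and `my_outputs` columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical main_dish labels for six recipes.\n",
"gpt4_main_dish = pd.Series([True, True, False, False, True, False], name=\"GPT-4\")\n",
"model_main_dish = pd.Series([True, False, True, False, True, False], name=\"My model\")\n",
"\n",
"# Rows are GPT-4's label, columns are the model's label; the off-diagonal\n",
"# cells count disagreements in each direction.\n",
"confusion = pd.crosstab(gpt4_main_dish, model_main_dish)\n",
"print(confusion)"
]
},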
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at these outputs, it's clear that our model still makes some mistakes. At the same time, there are examples like \"Rhonda's Butter Chess Pie\" where our model got it right even though GPT-4 got it wrong! There are also cases like the \"Veggie Casserole\", where the \"right\" answer is genuinely ambiguous and both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}