more benchmarking
@@ -9,7 +9,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
@@ -111,18 +111,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
-     "ename": "ModuleNotFoundError",
-     "evalue": "No module named 'axolotl.prompters'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39maxolotl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompters\u001b[39;00m \u001b[39mimport\u001b[39;00m UnpromptedPrompter\n\u001b[1;32m      2\u001b[0m prompter \u001b[39m=\u001b[39m UnpromptedPrompter()\n\u001b[1;32m      4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_prompt\u001b[39m(\u001b[39minput\u001b[39m: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'axolotl.prompters'"
-     ]
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Sample prompt:\n",
+      "--------------\n",
+      "### Instruction:\n",
+      "[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions. You should \"},{\"role\":\"user\",\"content\":\"Strawberry Sorbet\\n\\nIngredients:\\n- 2 cups chopped strawberries Safeway 1 lb For $3.99 thru 02/09\\n- 1 cup cold water\\n- 2 cups boiling water\\n- 1 pkg. (4-serving size) JELL-O Strawberry Flavor Gelatin\\n- 1/2 cup sugar\\n\\nDirections:\\n- Place strawberries and cold water in blender container; cover.\\n- Blend on high speed until smooth.\\n- Stir boiling water into combined dry gelatin mix and sugar in medium bowl at least 2 minutes until completely dissolved.\\n- Add strawberry mixture; mix well.\\n- Pour into 9-inch square pan.\\n- Freeze 1 to 1-1/2 hours or until ice crystals form 1 inch around edges of pan.\\n- Spoon half of the gelatin mixture into blender container; cover.\\n- Blend on high speed about 30 seconds or until smooth; pour into bowl.\\n- Repeat with remaining gelatin mixture.\\n- Add to blended gelatin mixture in bowl; mix well.\\n- Return to pan.\\n- Freeze 6 hours or overnight until firm.\\n- Scoop into dessert dishes to serve.\\n- Store leftover sorbet in freezer.\"}]\n",
+      "\n",
+      "### Response:\n",
+      "\n"
+     ]
     }
    ],
@@ -136,9 +138,134 @@
     "    return next(prompter.build_prompt(input))\n",
     "\n",
     "\n",
-    "prompts = test_data[\"input\"].apply(format_prompt)\n",
+    "prompts = test_data[\"instruction\"].apply(format_prompt)\n",
     "\n",
-    "print(f\"Sample prompt:\\n-----------\\n{prompts[0]}\")\n"
+    "print(f\"Sample prompt:\\n--------------\\n{prompts[0]}\")\n"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO 08-24 18:58:03 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
+      "INFO 08-24 18:59:18 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processed prompts: 100%|██████████| 201/201 [00:16<00:00, 12.18it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Sample output:\n",
+      "--------------\n",
+      "{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": false,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_course\\\": false\\n}\"}}\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from vllm import LLM, SamplingParams\n",
+    "\n",
+    "llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
+    "\n",
+    "sampling_params = SamplingParams(\n",
+    "    # 120 should be fine for the work we're doing here.\n",
+    "    max_tokens=120,\n",
+    "    # This is a deterministic task so temperature=0 is best.\n",
+    "    temperature=0,\n",
+    ")\n",
+    "\n",
+    "my_outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
+    "my_outputs = [o.outputs[0].text for o in my_outputs]\n",
+    "\n",
+    "test_data[\"my_outputs\"] = my_outputs\n",
+    "\n",
+    "print(f\"Sample output:\\n--------------\\n{my_outputs[0]}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Ok, we have our outputs! Since there are 5 categories we classify each recipe on, a natural metric is: for each recipe and each category, what percentage of the time does our model's output match GPT-4's? I'll write a quick eval function to check that."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Overall accuracy: 0.91\n"
+     ]
+    }
+   ],
+   "source": [
+    "import json\n",
+    "\n",
+    "\n",
+    "def parse_fn_call_args(response_str):\n",
+    "    \"\"\"Parse the function call arguments from the response\"\"\"\n",
+    "    response_dict = json.loads(response_str)\n",
+    "    args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
+    "\n",
+    "    return args_dict\n",
+    "\n",
+    "\n",
+    "def calculate_accuracy(row):\n",
+    "    \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
+    "    true_outputs = parse_fn_call_args(row[\"output\"])\n",
+    "    my_outputs = parse_fn_call_args(row[\"my_outputs\"])\n",
+    "\n",
+    "    num_matching_outputs = 0\n",
+    "    for key in true_outputs.keys():\n",
+    "        if true_outputs[key] == my_outputs[key]:\n",
+    "            num_matching_outputs += 1\n",
+    "\n",
+    "    return num_matching_outputs / len(true_outputs)\n",
+    "\n",
+    "\n",
+    "test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
+    "\n",
+    "print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Not bad! Of course, the next obvious step is to look at where Llama 2 is \"wrong\" and evaluate the types of errors it makes. I've exported a Google Sheet where I did exactly that with an earlier version of this model trained on the same dataset. You can see that [here](https://docs.google.com/spreadsheets/d/1vn-nA0CRQwz-BvEYvxUcO1-EP80ZbPhcxDoCTttvsmI/edit?usp=sharing).\n",
+    "\n",
+    "The main takeaway: generally, the places where GPT-4 and Llama 2 disagreed were genuinely ambiguous cases where either answer was acceptable (e.g. a dish that takes about 30 minutes to cook might be classified as over 30 minutes by one and under 30 minutes by the other).\n",
+    "\n",
+    "Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
+   ]
+  }
 ],