more benchmarking
This commit is contained in:
248
examples/classify-recipes/benchmarking.ipynb
Normal file
248
examples/classify-recipes/benchmarking.ipynb
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
|
||||||
|
"\n",
|
||||||
|
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
|
||||||
|
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
|
||||||
|
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
|
||||||
|
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
|
||||||
|
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
|
||||||
|
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
|
||||||
|
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
|
||||||
|
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
|
||||||
|
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
|
||||||
|
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
|
||||||
|
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
|
||||||
|
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
|
||||||
|
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
|
||||||
|
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
|
||||||
|
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
|
||||||
|
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
|
||||||
|
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
|
||||||
|
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
|
||||||
|
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
|
||||||
|
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
|
||||||
|
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
|
||||||
|
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
|
||||||
|
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
|
||||||
|
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
|
||||||
|
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
|
||||||
|
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
|
||||||
|
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
|
||||||
|
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
|
||||||
|
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
|
||||||
|
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
|
||||||
|
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
|
||||||
|
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
|
||||||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
|
||||||
|
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
|
||||||
|
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
|
||||||
|
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
|
||||||
|
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
|
||||||
|
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
|
||||||
|
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
|
||||||
|
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
|
||||||
|
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
|
||||||
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
|
||||||
|
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
|
||||||
|
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
|
||||||
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
|
||||||
|
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
|
||||||
|
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
|
||||||
|
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
|
||||||
|
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
|
||||||
|
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
|
||||||
|
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
|
||||||
|
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
|
||||||
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
|
||||||
|
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||||
|
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
|
||||||
|
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
|
||||||
|
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
|
||||||
|
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
|
||||||
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
|
||||||
|
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
|
||||||
|
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
|
||||||
|
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
|
||||||
|
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
|
||||||
|
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
|
||||||
|
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
|
||||||
|
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
|
||||||
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
|
||||||
|
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%pip install datasets==2.14.4 vllm==0.1.3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Number of recipes: 2,147,248\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from datasets import load_dataset\n",
|
||||||
|
"\n",
|
||||||
|
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
|
||||||
|
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from vllm import LLM, SamplingParams\n",
|
||||||
|
"\n",
|
||||||
|
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
|
||||||
|
"\n",
|
||||||
|
"sampling_params = SamplingParams(\n",
|
||||||
|
" # 120 should be fine for the work we're doing here.\n",
|
||||||
|
" max_tokens=120,\n",
|
||||||
|
" # This is a deterministic task so temperature=0 is best.\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Start time: 1692906050.3340027\n",
|
||||||
|
"Processing recipes 0 to 10,000...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processing recipes 10,000 to 20,000...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processing recipes 20,000 to 30,000...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processed prompts: 30%|███ | 3008/10000 [01:29<03:27, 33.73it/s]"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# We'll process our recipes in batches of 10,000.\n",
|
||||||
|
"\n",
|
||||||
|
"import time\n",
|
||||||
|
"\n",
|
||||||
|
"BATCH_SIZE = 10000\n",
|
||||||
|
"all_outputs = []\n",
|
||||||
|
"\n",
|
||||||
|
"start_time = time.time()\n",
|
||||||
|
"print(f\"Start time: {start_time}\")\n",
|
||||||
|
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
|
||||||
|
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
|
||||||
|
" outputs = llm.generate(\n",
|
||||||
|
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
|
||||||
|
"\n",
|
||||||
|
"end_time = time.time()\n",
|
||||||
|
"print(f\"End time: {end_time}\")\n",
|
||||||
|
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.6"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -111,18 +111,20 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "ModuleNotFoundError",
|
"name": "stdout",
|
||||||
"evalue": "No module named 'axolotl.prompters'",
|
"output_type": "stream",
|
||||||
"output_type": "error",
|
"text": [
|
||||||
"traceback": [
|
"Sample prompt:\n",
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
"--------------\n",
|
||||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
"### Instruction:\n",
|
||||||
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39maxolotl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompters\u001b[39;00m \u001b[39mimport\u001b[39;00m UnpromptedPrompter\n\u001b[1;32m 2\u001b[0m prompter \u001b[39m=\u001b[39m UnpromptedPrompter()\n\u001b[1;32m 4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_prompt\u001b[39m(\u001b[39minput\u001b[39m: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n",
|
"[{\"role\":\"system\",\"content\":\"Your goal is to classify a recipe along several dimensions. You should \"},{\"role\":\"user\",\"content\":\"Strawberry Sorbet\\n\\nIngredients:\\n- 2 cups chopped strawberries Safeway 1 lb For $3.99 thru 02/09\\n- 1 cup cold water\\n- 2 cups boiling water\\n- 1 pkg. (4-serving size) JELL-O Strawberry Flavor Gelatin\\n- 1/2 cup sugar\\n\\nDirections:\\n- Place strawberries and cold water in blender container; cover.\\n- Blend on high speed until smooth.\\n- Stir boiling water into combined dry gelatin mix and sugar in medium bowl at least 2 minutes until completely dissolved.\\n- Add strawberry mixture; mix well.\\n- Pour into 9-inch square pan.\\n- Freeze 1 to 1-1/2 hours or until ice crystals form 1 inch around edges of pan.\\n- Spoon half of the gelatin mixture into blender container; cover.\\n- Blend on high speed about 30 seconds or until smooth; pour into bowl.\\n- Repeat with remaining gelatin mixture.\\n- Add to blended gelatin mixture in bowl; mix well.\\n- Return to pan.\\n- Freeze 6 hours or overnight until firm.\\n- Scoop into dessert dishes to serve.\\n- Store leftover sorbet in freezer.\"}]\n",
|
||||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'axolotl.prompters'"
|
"\n",
|
||||||
|
"### Response:\n",
|
||||||
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -136,9 +138,134 @@
|
|||||||
" return next(prompter.build_prompt(input))\n",
|
" return next(prompter.build_prompt(input))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"prompts = test_data[\"input\"].apply(format_prompt)\n",
|
"prompts = test_data[\"instruction\"].apply(format_prompt)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(f\"Sample prompt:\\n-----------\\n{prompts[0]}\")\n"
|
"print(f\"Sample prompt:\\n--------------\\n{prompts[0]}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO 08-24 18:58:03 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
|
||||||
|
"INFO 08-24 18:59:18 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Processed prompts: 100%|██████████| 201/201 [00:16<00:00, 12.18it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Sample output:\n",
|
||||||
|
"--------------\n",
|
||||||
|
"{\"role\":\"assistant\",\"content\":null,\"function_call\":{\"name\":\"classify\",\"arguments\":\"{\\n\\\"has_non_fish_meat\\\": false,\\n\\\"requires_oven\\\": false,\\n\\\"requires_stove\\\": true,\\n\\\"cook_time_over_30_mins\\\": false,\\n\\\"main_course\\\": false\\n}\"}}\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from vllm import LLM, SamplingParams\n",
|
||||||
|
"\n",
|
||||||
|
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
|
||||||
|
"\n",
|
||||||
|
"sampling_params = SamplingParams(\n",
|
||||||
|
" # 120 should be fine for the work we're doing here.\n",
|
||||||
|
" max_tokens=120,\n",
|
||||||
|
" # This is a deterministic task so temperature=0 is best.\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"my_outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
|
||||||
|
"my_outputs = [o.outputs[0].text for o in my_outputs]\n",
|
||||||
|
"\n",
|
||||||
|
"test_data[\"my_outputs\"] = my_outputs\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Sample output:\\n--------------\\n{my_outputs[0]}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Ok, we have our outputs! Since we classify each recipe along 5 categories, a natural metric is the percentage of (recipe, category) pairs for which our model's output matches GPT-4's. I'll write a quick eval function to check that."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Overall accuracy: 0.91\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def parse_fn_call_args(str):\n",
|
||||||
|
" \"\"\"Parse the function call arguments from the response\"\"\"\n",
|
||||||
|
" response_dict = json.loads(str)\n",
|
||||||
|
" args_dict = json.loads(response_dict[\"function_call\"][\"arguments\"])\n",
|
||||||
|
"\n",
|
||||||
|
" return args_dict\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def calculate_accuracy(row):\n",
|
||||||
|
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
|
||||||
|
" true_outputs = parse_fn_call_args(row[\"output\"])\n",
|
||||||
|
" my_outputs = parse_fn_call_args(row[\"my_outputs\"])\n",
|
||||||
|
"\n",
|
||||||
|
" num_matching_outputs = 0\n",
|
||||||
|
" for key in true_outputs.keys():\n",
|
||||||
|
" if true_outputs[key] == my_outputs[key]:\n",
|
||||||
|
" num_matching_outputs += 1\n",
|
||||||
|
"\n",
|
||||||
|
" return num_matching_outputs / len(true_outputs)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Not bad! Of course, the next obvious step is to look at where Llama 2 is \"wrong\" and evaluate the types of errors it makes. I've exported a Google Sheet where I did exactly that with an earlier version of this model trained on the same dataset. You can see that [here](https://docs.google.com/spreadsheets/d/1vn-nA0CRQwz-BvEYvxUcO1-EP80ZbPhcxDoCTttvsmI/edit?usp=sharing).\n",
|
||||||
|
"\n",
|
||||||
|
"The main takeaway: the places where GPT-4 and Llama 2 disagreed were mostly genuinely ambiguous cases, where either answer was acceptable (e.g. a dish that takes about 30 mins to cook might be classified as over 30 minutes by one, and under 30 minutes by the other).\n",
|
||||||
|
"\n",
|
||||||
|
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -19,6 +19,8 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
|
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
|
||||||
|
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
|
||||||
|
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
|
||||||
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
|
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
|
||||||
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
|
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
|
||||||
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
|
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
|
||||||
@@ -63,14 +65,15 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Recipe dataset:\n"
|
"Recipe dataset shape:\n",
|
||||||
|
"------------------\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -90,7 +93,7 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"First recipe:\n",
|
"First recipe:\n",
|
||||||
" Shrimp Creole\n",
|
"------------------ Shrimp Creole\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Ingredients:\n",
|
"Ingredients:\n",
|
||||||
"- 20 shrimp (8 oz.)\n",
|
"- 20 shrimp (8 oz.)\n",
|
||||||
@@ -128,27 +131,37 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
|
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`."
|
"In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
|
||||||
|
" - has_non_fish_meat\n",
|
||||||
|
" - requires_oven\n",
|
||||||
|
" - requires_stove\n",
|
||||||
|
" - cook_time_over_30_mins\n",
|
||||||
|
" - main_dish\n",
|
||||||
|
"\n",
|
||||||
|
"That looks like a pretty random list, but there's actually an important unifying thread: we're looking for meals that my pescatarian brother can eat in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
|
||||||
|
"\n",
|
||||||
|
"We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"name": "stdout",
|
||||||
"text/plain": [
|
"output_type": "stream",
|
||||||
"{'has_non_fish_meat': False,\n",
|
"text": [
|
||||||
" 'requires_oven': True,\n",
|
"Classifying first recipe:\n",
|
||||||
" 'requires_stove': True,\n",
|
"------------------\n"
|
||||||
" 'cook_time_over_30_mins': False,\n",
|
]
|
||||||
" 'main_dish': False}"
|
},
|
||||||
]
|
{
|
||||||
},
|
"name": "stdout",
|
||||||
"execution_count": 9,
|
"output_type": "stream",
|
||||||
"metadata": {},
|
"text": [
|
||||||
"output_type": "execute_result"
|
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': True, 'main_dish': True}\n"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -157,10 +170,13 @@
|
|||||||
"import os\n",
|
"import os\n",
|
||||||
"import dotenv\n",
|
"import dotenv\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Use `dotenv` to load the contents of the `.env` file into the environment\n",
|
||||||
"dotenv.load_dotenv()\n",
|
"dotenv.load_dotenv()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Configure OpenPipe using the API key from the environment\n",
|
||||||
"configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n",
|
"configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Configure OpenAI using the API key from the environment\n",
|
||||||
"openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
|
"openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -222,12 +238,20 @@
|
|||||||
" return json.loads(completion.choices[0].message.function_call.arguments)\n",
|
" return json.loads(completion.choices[0].message.function_call.arguments)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"classify_recipe(recipes[\"recipe\"][-1])\n"
|
"print(\"Classifying first recipe:\\n------------------\")\n",
|
||||||
|
"print(classify_recipe(recipes[\"recipe\"][0]))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- let's see if we can make it as good!"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 14,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -238,8 +262,6 @@
|
|||||||
"Classifying recipe 100/5000: Spoon Bread\n",
|
"Classifying recipe 100/5000: Spoon Bread\n",
|
||||||
"Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n",
|
"Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n",
|
||||||
"Classifying recipe 300/5000: Broccoli Casserole\n",
|
"Classifying recipe 300/5000: Broccoli Casserole\n",
|
||||||
"Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
|
|
||||||
"520 is not a valid HTTPStatus\n",
|
|
||||||
"Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n",
|
"Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n",
|
||||||
"Classifying recipe 500/5000: Dirt Dessert\n",
|
"Classifying recipe 500/5000: Dirt Dessert\n",
|
||||||
"Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n",
|
"Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n",
|
||||||
@@ -265,21 +287,22 @@
|
|||||||
"Classifying recipe 2600/5000: Pepperoni Bread\n",
|
"Classifying recipe 2600/5000: Pepperoni Bread\n",
|
||||||
"Classifying recipe 2700/5000: Sabzi Polow\n",
|
"Classifying recipe 2700/5000: Sabzi Polow\n",
|
||||||
"Classifying recipe 2800/5000: Italian Vegetable Pizzas\n",
|
"Classifying recipe 2800/5000: Italian Vegetable Pizzas\n",
|
||||||
"Error classifying recipe 2801: Bad gateway. {\"error\":{\"code\":502,\"message\":\"Bad gateway.\",\"param\":null,\"type\":\"cf_bad_gateway\"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Thu, 24 Aug 2023 15:44:45 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7fbca943df684de1-MCI', 'alt-svc': 'h3=\":443\"; ma=86400'}\n",
|
|
||||||
"Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n",
|
"Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n",
|
||||||
"Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n",
|
"Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n",
|
||||||
"Classifying recipe 3100/5000: Herbed Potatoes And Onions\n",
|
"Classifying recipe 3100/5000: Herbed Potatoes And Onions\n",
|
||||||
"Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n",
|
"Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n",
|
||||||
"Classifying recipe 3300/5000: Pineapple-Orange Punch\n",
|
"Classifying recipe 3300/5000: Pineapple-Orange Punch\n",
|
||||||
"Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n",
|
"Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n",
|
||||||
"Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
|
|
||||||
"520 is not a valid HTTPStatus\n",
|
|
||||||
"Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n",
|
"Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n",
|
||||||
"Classifying recipe 3600/5000: Triple Chocolate Cookies\n",
|
"Classifying recipe 3600/5000: Triple Chocolate Cookies\n",
|
||||||
"Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n",
|
"Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n",
|
||||||
"Error classifying recipe 3779: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600)\n",
|
|
||||||
"Classifying recipe 3800/5000: Chicken Croquettes\n",
|
"Classifying recipe 3800/5000: Chicken Croquettes\n",
|
||||||
"Classifying recipe 3900/5000: Mushroom Casserole\n"
|
"Classifying recipe 3900/5000: Mushroom Casserole\n",
|
||||||
|
"Classifying recipe 4000/5000: Vegetarian Summer Roll\n",
|
||||||
|
"Classifying recipe 4100/5000: Prune Cake\n",
|
||||||
|
"Classifying recipe 4200/5000: Strawberry Sorbet\n",
|
||||||
|
"Classifying recipe 4300/5000: Lemonade Chicken\n",
|
||||||
|
"Classifying recipe 4400/5000: Crock-Pot Vegetarian Chili\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -295,12 +318,14 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "markdown",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"Ok, we have our "
|
"Ok, now that my recipes are classified I'll download the training data. \n",
|
||||||
|
"\n",
|
||||||
|
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details! But first, to download the training data, go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
|
||||||
|
"\n",
|
||||||
|
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`."
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@@ -51,12 +51,12 @@
|
|||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
|
||||||
"Note: you may need to restart the kernel to use updated packages.\n",
|
"Note: you may need to restart the kernel to use updated packages.\n",
|
||||||
"fatal: destination path 'axolotl' already exists and is not an empty directory.\n",
|
"fatal: destination path 'axolotl' already exists and is not an empty directory.\n",
|
||||||
"Obtaining file:///workspace/gpt4-fine-tuning/axolotl\n",
|
"Obtaining file:///workspace/OpenPipe/examples/classify-recipes/axolotl\n",
|
||||||
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
|
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
|
||||||
"\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n",
|
"\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n",
|
||||||
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-8yfermge/transformers_22e4388baf16446d8445557008e38efe\n",
|
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-o3o9dk76/transformers_99ed72a1465e41bba173c85b3be82a1b\n",
|
||||||
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-8yfermge/transformers_22e4388baf16446d8445557008e38efe\n",
|
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-o3o9dk76/transformers_99ed72a1465e41bba173c85b3be82a1b\n",
|
||||||
" Resolved https://github.com/huggingface/transformers.git to commit 4d40109c3a93c9b8bbca204cb046ed510f1c72e8\n",
|
" Resolved https://github.com/huggingface/transformers.git to commit f26099e7b5cf579f99a42bab6ddd371bf2c8d548\n",
|
||||||
" Installing build dependencies ... \u001b[?25ldone\n",
|
" Installing build dependencies ... \u001b[?25ldone\n",
|
||||||
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
|
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
|
||||||
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
||||||
@@ -152,10 +152,6 @@
|
|||||||
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n",
|
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n",
|
||||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n",
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n",
|
||||||
"Installing collected packages: axolotl\n",
|
"Installing collected packages: axolotl\n",
|
||||||
" Attempting uninstall: axolotl\n",
|
|
||||||
" Found existing installation: axolotl 0.1\n",
|
|
||||||
" Uninstalling axolotl-0.1:\n",
|
|
||||||
" Successfully uninstalled axolotl-0.1\n",
|
|
||||||
" Running setup.py develop for axolotl\n",
|
" Running setup.py develop for axolotl\n",
|
||||||
"Successfully installed axolotl\n",
|
"Successfully installed axolotl\n",
|
||||||
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
||||||
@@ -421,6 +417,13 @@
|
|||||||
"final_model_dir = merge_lora_model(\"training-config.yaml\")\n",
|
"final_model_dir = merge_lora_model(\"training-config.yaml\")\n",
|
||||||
"print(f\"Final model saved to '{final_model_dir}'\")\n"
|
"print(f\"Final model saved to '{final_model_dir}'\")\n"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Reference in New Issue
Block a user