Compare commits

..

10 Commits

Author SHA1 Message Date

Kyle Corbitt
ead981b900 Disable custom models for the moment 2023-08-25 16:01:51 -07:00
We're running into GPU constraints and need to turn off custom models until we find a better provider or can hot-swap them.

Kyle Corbitt
e0d0cc0df1 Merge pull request #193 from OpenPipe/examples 2023-08-25 12:46:01 -07:00
classify-recipes example

arcticfly
7df1c59bd3 Update README.md 2023-08-24 22:23:40 -07:00

arcticfly
c83863f468 Update README.md 2023-08-24 20:59:46 -07:00

arcticfly
33ca98b267 Enlarge fine-tune gif 2023-08-24 14:14:38 -07:00

arcticfly
39c943f2ec Change layout of README.md 2023-08-24 14:13:01 -07:00

arcticfly
2aa4ac1594 Update opening gif in README.md 2023-08-24 12:46:25 -07:00

arcticfly
42ade01f22 Update README.md 2023-08-24 11:14:25 -07:00

David Corbitt
59b79049c1 Move license to top level 2023-08-24 10:41:23 -07:00

arcticfly
0d7433cb7e Update README.md 2023-08-24 00:13:24 -07:00
Include more models
13 changed files with 1066 additions and 636 deletions

View File

@@ -1,14 +1,52 @@
# OpenPipe
<p align="center">
<a href="https://openpipe.ai">
<img height="70" src="https://github.com/openpipe/openpipe/assets/41524992/70af25fb-1f90-42d9-8a20-3606e3b5aaba" alt="logo">
</a>
</p>
<h1 align="center">
OpenPipe
</h1>
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts, and can automatically [translate](#-translate-between-model-apis) those prompts between models.
<p align="center">
<i>Turn expensive prompts into cheap fine-tuned models.</i>
</p>
<img src="https://github.com/openpipe/openpipe/assets/41524992/66bb1843-cb72-4130-a369-eec2df3b8201" alt="demo">
<p align="center">
<a href="/LICENSE"><img alt="License Apache-2.0" src="https://img.shields.io/github/license/openpipe/openpipe?style=flat-square"></a>
<a href='http://makeapullrequest.com'><img alt='PRs Welcome' src='https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square'/></a>
<a href="https://github.com/openpipe/openpipe/graphs/commit-activity"><img alt="GitHub commit activity" src="https://img.shields.io/github/commit-activity/m/openpipe/openpipe?style=flat-square"/></a>
<a href="https://github.com/openpipe/openpipe/issues"><img alt="GitHub closed issues" src="https://img.shields.io/github/issues-closed/openpipe/openpipe?style=flat-square"/></a>
</p>
<p align="center">
<a href="https://app.openpipe.ai/">Hosted App</a> - <a href="#running-locally">Running Locally</a> - <a href="#sample-experiments">Experiments</a>
</p>
<br>
Use powerful but expensive LLMs to fine-tune smaller and cheaper models suited to your exact needs. Evaluate model and prompt combinations in the playground. Query your past requests and export optimized training data. Try it out at https://app.openpipe.ai or <a href="#running-locally">run it locally</a>.
<br>
## 🪛 Features
* <b>Experiment</b>
* Bulk-test a wide range of scenarios using code templating.
* Seamlessly translate prompts across different model APIs.
* Tap into autogenerated scenarios for fresh test perspectives.
* <b>Fine-Tune (Beta)</b>
* Easy integration with OpenPipe's SDK in both Python and JS.
* Swiftly query logs using intuitive built-in filters.
* Export data in multiple training formats, including Alpaca and ChatGPT, with deduplication (see the sketch below).
<img src="https://github.com/openpipe/openpipe/assets/41524992/eaa8b92d-4536-4f63-bbef-4b0b1a60f6b5" alt="fine-tune demo">
<!-- <img height="400px" src="https://github.com/openpipe/openpipe/assets/41524992/66bb1843-cb72-4130-a369-eec2df3b8201" alt="playground demo"> -->
You can use our hosted version of OpenPipe at https://openpipe.ai. You can also clone this repository and [run it locally](#running-locally).
## Sample Experiments
These are simple experiments users have created that show how OpenPipe works. Feel free to fork them and start experimenting yourself.
These are sample experiments users have created that show how OpenPipe works. Feel free to fork them and start experimenting yourself.
- [Twitter Sentiment Analysis](https://app.openpipe.ai/experiments/62c20a73-2012-4a64-973c-4b665ad46a57)
- [Reddit User Needs](https://app.openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
@@ -17,37 +55,25 @@ These are simple experiments users have created that show how OpenPipe works. Fe
## Supported Models
- All models available through the OpenAI [chat completion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
- Llama2 [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat), [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat), [70b chat](https://replicate.com/replicate/llama70b-v2-chat).
- Anthropic's [Claude 1 Instant](https://www.anthropic.com/index/introducing-claude) and [Claude 2](https://www.anthropic.com/index/claude-2)
## Features
### 🔍 Visualize Responses
Inspect prompt completions side-by-side.
### 🧪 Bulk-Test
OpenPipe lets you _template_ a prompt. Use the templating feature to run the prompts you're testing against many potential inputs for broad coverage of your problem space.
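A minimal sketch of the idea, using plain Python string formatting to stand in for OpenPipe's own template syntax:

```python
# One templated prompt, many scenarios: each scenario fills the template
# with different values, giving broad coverage from a single prompt.
template = "Classify the sentiment of this tweet as positive, negative, or neutral: {tweet}"

scenarios = [
    {"tweet": "I love this new phone!"},
    {"tweet": "Worst customer service ever."},
    {"tweet": "The package arrived on Tuesday."},
]

for scenario in scenarios:
    prompt = template.format(**scenario)
    print(prompt)  # each filled-in prompt runs against the models under test
```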
### 📟 Translate between Model APIs
Write your prompt in one format and automatically convert it to work with any other model.
<!-- <img width="480" alt="Screenshot 2023-08-01 at 11 55 38 PM" src="https://github.com/OpenPipe/OpenPipe/assets/41524992/1e19ccf2-96b6-4e93-a3a5-1449710d1b5b" alt="translate between models"> -->
### 🛠️ Refine Your Prompts Automatically
Use a growing database of best-practice refinements to improve your prompts automatically.
<!-- <img width="480" alt="Screenshot 2023-08-01 at 11 55 38 PM" src="https://github.com/OpenPipe/OpenPipe/assets/41524992/87a27fe7-daef-445c-a5e2-1c82b23f9f99" alt="add function call"> -->
### 🪄 Auto-generate Test Scenarios
OpenPipe includes a tool to generate new test scenarios based on your existing prompts and scenarios. Just click "Autogenerate Scenario" to try it out!
<!-- <img width="600" src="https://github.com/openpipe/openpipe/assets/41524992/219a844e-3f4e-4f6b-8066-41348b42977b" alt="auto-generate"> -->
#### OpenAI
- [GPT 3.5 Turbo](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
- [GPT 3.5 Turbo 16k](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
- [GPT 4](https://openai.com/gpt-4)
#### Llama2
- [7b chat](https://replicate.com/a16z-infra/llama7b-v2-chat)
- [13b chat](https://replicate.com/a16z-infra/llama13b-v2-chat)
- [70b chat](https://replicate.com/replicate/llama70b-v2-chat)
#### Llama2 Fine-Tunes
- [Open-Orca/OpenOrcaxOpenChat-Preview2-13B](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B)
- [Open-Orca/OpenOrca-Platypus2-13B](https://huggingface.co/Open-Orca/OpenOrca-Platypus2-13B)
- [NousResearch/Nous-Hermes-Llama2-13b](https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b)
- [jondurbin/airoboros-l2-13b-gpt4-2.0](https://huggingface.co/jondurbin/airoboros-l2-13b-gpt4-2.0)
- [lmsys/vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5)
- [Gryphe/MythoMax-L2-13b](https://huggingface.co/Gryphe/MythoMax-L2-13b)
- [NousResearch/Nous-Hermes-llama-2-7b](https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b)
#### Anthropic
- [Claude 1 Instant](https://www.anthropic.com/index/introducing-claude)
- [Claude 2](https://www.anthropic.com/index/claude-2)
## Running Locally

View File

@@ -147,9 +147,10 @@ export default function OutputCell({
<ResponseLog
time={response.receivedAt}
title="Response received from API"
message={`statusCode: ${response.statusCode ?? ""}\n ${
response.errorMessage ?? ""
}`}
message={[
response.statusCode ? `Status: ${response.statusCode}\n` : "",
response.errorMessage ?? "",
].join("")}
/>
)}
</Fragment>

View File

@@ -17,10 +17,23 @@ const modelEndpoints: Record<OpenpipeChatInput["model"], string> = {
"NousResearch/Nous-Hermes-llama-2-7b": "https://ua1bpc6kv3dgge-8000.proxy.runpod.net/v1",
};
const CUSTOM_MODELS_ENABLED = false;
export async function getCompletion(
input: OpenpipeChatInput,
onStream: ((partialOutput: OpenpipeChatOutput) => void) | null,
): Promise<CompletionResponse<OpenpipeChatOutput>> {
// Temporarily disable these models because of GPU constraints
if (!CUSTOM_MODELS_ENABLED) {
return {
type: "error",
message:
"We've disabled this model temporarily because of GPU capacity constraints. Check back later.",
autoRetry: false,
};
}
const { model, messages, ...rest } = input;
const templatedPrompt = frontendModelProvider.models[model].templatePrompt?.(messages);

examples/.gitignore (vendored): 4 changes

View File

@@ -1,6 +1,4 @@
axolotl/
models/
data/
wandb/
cache/
.ipynb_checkpoints/
wandb/

View File

@@ -0,0 +1,473 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 21:25:06\n",
"Current Time: 2023-08-24 21:25:36\n"
]
}
],
"source": [
"import time\n",
"\n",
"while True:\n",
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
" print(f\"Current Time: {current_time}\")\n",
" time.sleep(30)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 30,000 to 40,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 40,000 to 50,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Cost to Classify One Recipe</th>\n",
" <th>Cost to Classify Entire Dataset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
"finetuned_hourly_cost = 1.14\n",
"\n",
"finetuned_total_hours = 16.54\n",
"\n",
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
"\n",
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
"avg_input_tokens = 276\n",
"avg_output_tokens = 42\n",
"\n",
"# Token pricing from https://openai.com/pricing\n",
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
"\n",
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
"\n",
"gpt_35_finetuned_avg_cost = (\n",
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"# Multiply the number of recipes\n",
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
"\n",
"# Let's put this in a dataframe for easier comparison.\n",
"\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
" finetuned_avg_cost,\n",
" gpt_35_avg_cost,\n",
" gpt_35_finetuned_avg_cost,\n",
" gpt_4_avg_cost,\n",
" ],\n",
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"\n",
"costs\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,123 @@
# %% [markdown]
# I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?
#
# I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈.
# %%
from datasets import load_dataset
all_recipes = load_dataset("corbt/all-recipes")["train"]["input"]
print(f"Number of recipes: {len(all_recipes):,}")
# %%
from vllm import LLM, SamplingParams
llm = LLM(model="./models/run1/merged", max_num_batched_tokens=4096)
sampling_params = SamplingParams(
# 120 should be fine for the work we're doing here.
max_tokens=120,
# This is a deterministic task so temperature=0 is best.
temperature=0,
)
# %%
import os
import time
import json
BATCH_SIZE = 10000
start_time = time.time()
print(f"Start time: {start_time}")
for i in range(0, len(all_recipes), BATCH_SIZE):
# File name for the current batch
file_name = f"./data/benchmark_batch_{int(i/BATCH_SIZE)}.txt"
# Check if the file already exists; if so, skip to the next batch
if os.path.exists(file_name):
print(f"File {file_name} exists, skipping recipes {i:,} to {i+BATCH_SIZE:,}...")
continue
print(f"Processing recipes {i:,} to {i+BATCH_SIZE:,}...")
outputs = llm.generate(
all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params
)
outputs = [o.outputs[0].text for o in outputs]
# Write the generated outputs to the file as a JSON list
json.dump(outputs, open(file_name, "w"))
end_time = time.time()
print(f"End time: {end_time}")
print(f"Total hours: {((end_time - start_time) / 3600):.2f}")
# %% [markdown]
# Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data.
# %%
import pandas as pd
# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.
finetuned_hourly_cost = 1.14
finetuned_total_hours = 17
finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)
# The average input and output token counts calculated by OpenAI, based on the 5000 recipes I sent them
avg_input_tokens = 276
avg_output_tokens = 42
# Token pricing from https://openai.com/pricing
gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000
gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000
gpt_35_finetuned_avg_cost = (
avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000
)
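# Quick arithmetic check of the per-recipe figure above (illustrative): for
# GPT-4, 276 * 0.03 / 1000 + 42 * 0.06 / 1000 = 0.00828 + 0.00252 = 0.01080,
# which matches the 0.010800 shown in the cost table.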
# Multiply by the number of recipes
# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost
# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost
# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost
# Let's put this in a dataframe for easier comparison.
costs = pd.DataFrame(
{
"Model": [
"Llama 2 7B (finetuned)",
"GPT-3.5",
"GPT-3.5 (finetuned)",
"GPT-4",
],
"Cost to Classify One Recipe": [
finetuned_avg_cost,
gpt_35_avg_cost,
gpt_35_finetuned_avg_cost,
gpt_4_avg_cost,
],
}
)
costs["Cost to Classify Entire Dataset"] = (
costs["Cost to Classify One Recipe"] * len(all_recipes)
).map(lambda x: f"{x:,.2f}")
costs
# %% [markdown]
# ...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂
# %%

View File

@@ -6,22 +6,104 @@
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈.\n"
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -194,12 +276,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data.\n"
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -231,47 +313,47 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama2 7B (FT)</td>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.81</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1033.26</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (FT)</td>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8683.47</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23190.28</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama2 7B (FT) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (FT) 0.004044 \n",
"3 GPT-4 0.010800 \n",
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.81 \n",
"1 1033.26 \n",
"2 8683.47 \n",
"3 23190.28 "
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 23,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -300,12 +382,12 @@
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"models = pd.DataFrame(\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama2 7B (FT)\",\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (FT)\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
@@ -317,11 +399,12 @@
" }\n",
")\n",
"\n",
"models[\"Cost to Classify Entire Dataset\"] = (\n",
" models[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").round(2)\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"models\n"
"\n",
"costs\n"
]
}
],

View File

@@ -4,29 +4,96 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first.\n"
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install vllm==0.1.3 pandas==2.0.3 joblib==1.3.2"
"%pip install vllm==0.1.3 pandas==2.0.3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance.\n"
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance."
]
},
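The elided cell just loads that export; since `prepare.ipynb` put both files under `./data/`, it boils down to one `read_json` call:

```python
import pandas as pd

# One JSON object per line -- the 10% test split exported from OpenPipe.
test_data = pd.read_json("data/test.jsonl", lines=True)
print(f"{len(test_data)} test examples")
```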
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -39,12 +106,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results.\n"
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results."
]
},
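For readers unfamiliar with the format: Alpaca-style prompts wrap each example in instruction/response markers so the model knows where its completion should start. A rough sketch of the transformation (the exact preamble and whitespace Axolotl uses are an assumption here, not copied from its source):

```python
# Approximate Alpaca template -- Axolotl's exact wording/whitespace may differ.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)

def to_alpaca_prompt(instruction: str) -> str:
    """Format a test example the same way the training data was formatted."""
    return ALPACA_TEMPLATE.format(instruction=instruction)
```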
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -80,27 +147,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model.\n"
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
]
},
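The generation cell itself is elided above, but with vLLM 0.1.3 the core calls look roughly like this (the sampling settings and the prompt column name are illustrative, not the notebook's exact values):

```python
from vllm import LLM, SamplingParams

# Load the merged fine-tuned model once; vLLM handles batching internally.
llm = LLM(model="./models/run1/merged")
params = SamplingParams(temperature=0, max_tokens=256)  # greedy decoding

prompts = test_data["prompt"].tolist()  # hypothetical column of Alpaca-formatted prompts
outputs = llm.generate(prompts, params)
test_data["my_outputs"] = [o.outputs[0].text for o in outputs]
```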
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-28 00:26:23 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-28 00:27:26 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
"INFO 08-25 03:58:49 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-25 03:59:40 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.34it/s]"
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.42it/s]"
]
},
{
@@ -144,12 +211,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:\n"
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -172,29 +239,20 @@
" return args_dict\n",
"\n",
"\n",
"test_data[\"output_parsed\"] = test_data[\"output\"].apply(parse_fn_call)\n",
"test_data[\"my_outputs_parsed\"] = test_data[\"my_outputs\"].apply(parse_fn_call)\n",
"\n",
"\n",
"def calculate_accuracy(row, labels_col):\n",
"def calculate_accuracy(row):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = row[\"output_parsed\"]\n",
" labels_outputs = row[labels_col]\n",
"\n",
" # print(f\"true_outputs: {true_outputs}\")\n",
" # print(f\"my_outputs: {row[labels_col]}\")\n",
" true_outputs = parse_fn_call(row[\"output\"])\n",
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if key in labels_outputs and true_outputs[key] == labels_outputs[key]:\n",
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
"\n",
"\n",
"test_data[\"accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"my_outputs_parsed\"\n",
")\n",
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
"\n",
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
]
@@ -203,293 +261,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"95% seems good! However, we don't have much to compare it to. Let's see how GPT-3.5 would do on the same task as a baseline. We'll use the same prompt we used with GPT-4 to generate the labels.\n"
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample recipe:\n",
"--------------\n",
"Pan Gravy\n",
"\n",
"Ingredients:\n",
"- 1/3 cup all purpose flour\n",
"- 1/3 cup turkey drippings\n",
"- 3 cup water or broth\n",
"- 1/8 to 1/4 teaspoon salt\n",
"- 1/8 tsp pepper\n",
"\n",
"Directions:\n",
"- In a skillet or roasting pan, add flour to drippings; blend well.\n",
"- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\n",
"- Add water; cook until mixture boils and thickens, stirring constantly.\n",
"- Stir in salt and pepper.\n",
"- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\n",
"- *\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def extract_recipe(row):\n",
" \"\"\"Extract the recipe from the instruction\"\"\"\n",
" return json.loads(row[\"instruction\"])[1][\"content\"]\n",
"\n",
"\n",
"recipes = test_data.apply(extract_recipe, axis=1)\n",
"print(f\"Sample recipe:\\n--------------\\n{recipes[0]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classifying first recipe:\n",
"------------------\n",
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': False, 'main_dish': False}\n"
]
}
],
"source": [
"import joblib\n",
"import openai\n",
"import os\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"memory = joblib.Memory(\"./cache\", verbose=0)\n",
"\n",
"\n",
"@memory.cache\n",
"def classify_recipe_35(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": recipe,\n",
" },\n",
" ],\n",
" functions=[\n",
" {\n",
" \"name\": \"classify\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"has_non_fish_meat\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe contains any meat or meat products (eg. chicken broth) besides fish\",\n",
" },\n",
" \"requires_oven\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires an oven\",\n",
" },\n",
" \"requires_stove\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires a stove\",\n",
" },\n",
" \"cook_time_over_30_mins\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe takes over 30 minutes to prepare and cook, including waiting time\",\n",
" },\n",
" \"main_dish\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe can be served as a main dish\",\n",
" },\n",
" },\n",
" \"required\": [\n",
" \"has_non_fish_meat\",\n",
" \"requires_oven\",\n",
" \"requires_stove\",\n",
" \"cook_time_over_30_mins\",\n",
" \"main_dish\",\n",
" ],\n",
" },\n",
" }\n",
" ],\n",
" function_call={\n",
" \"name\": \"classify\",\n",
" },\n",
" )\n",
" return json.loads(completion.choices[0].message.function_call.arguments)\n",
"\n",
"\n",
"print(\"Classifying first recipe:\\n------------------\")\n",
"print(classify_recipe_35(recipes[0]))\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"100%|██████████| 500/500 [00:31<00:00, 15.77it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"test_data[\"gpt_3.5\"] = [classify_recipe_35(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 accuracy: 0.91\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 accuracy: {test_data['gpt_3.5_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And for completeness, let's try a fine-tuned GPT-3.5 model. You can find the fine-tuning code in [finetune-gpt-3.5.ipynb](./finetune-gpt-3.5.ipynb)\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'has_non_fish_meat': True,\n",
" 'requires_oven': False,\n",
" 'requires_stove': True,\n",
" 'cook_time_over_30_mins': False,\n",
" 'main_dish': False}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@memory.cache\n",
"def classify_recipe_35_ft(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"ft:gpt-3.5-turbo-0613:openpipe::7rZpPqYn\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several \"\n",
" \"dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\"role\": \"user\", \"content\": recipe},\n",
" ],\n",
" )\n",
"\n",
" return json.loads(completion.choices[0].message.content)\n",
"\n",
"\n",
"classify_recipe_35_ft(recipes[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 500/500 [07:31<00:00, 1.11it/s]\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft\"] = [classify_recipe_35_ft(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 FT accuracy: 0.94\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5_ft\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 FT accuracy: {test_data['gpt_3.5_ft_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look.\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -848,28 +625,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A realistic point of comparison here might be GPT-3.5. Let's try to classify the same set of recipes using GPT-3.5 and see how it does. We'll use the same prompt that we used with GPT-4 to generate the initial training data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!\n"
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
]
},
{

View File

@@ -1,207 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample training data:\n",
"{'messages': [{'content': 'Your goal is to classify a recipe along several '\n",
" 'dimensions.Pay attention to the instructions.',\n",
" 'role': 'system'},\n",
" {'content': 'Homemade Salad Dressing\\n'\n",
" '\\n'\n",
" 'Ingredients:\\n'\n",
" \"- 1 pt. Hellmann's mayonnaise\\n\"\n",
" '- 1 pt. buttermilk\\n'\n",
" '- 1 tsp. Accent\\n'\n",
" '- 2 Tbsp. dry parsley\\n'\n",
" '- 2 pkg. low-calorie Italian salad dressing mix\\n'\n",
" '- 1 can jalapeno peppers or 4 oz. Jimenez green '\n",
" 'sauce\\n'\n",
" '\\n'\n",
" 'Directions:\\n'\n",
" '- Blend well in blender; store in refrigerator.\\n'\n",
" '- For dip, decrease liquid.',\n",
" 'role': 'user'},\n",
" {'content': '{\\n'\n",
" '\"has_non_fish_meat\": false,\\n'\n",
" '\"requires_oven\": false,\\n'\n",
" '\"requires_stove\": false,\\n'\n",
" '\"cook_time_over_30_mins\": false,\\n'\n",
" '\"main_dish\": false\\n'\n",
" '}',\n",
" 'role': 'assistant'}]}\n"
]
}
],
"source": [
"import pandas as pd\n",
"from pprint import pprint\n",
"import json\n",
"\n",
"df = pd.read_json(\"data/train.jsonl\", lines=True)\n",
"\n",
"training_data = []\n",
"for row in df.itertuples():\n",
" input = json.loads(row.instruction)\n",
" output = json.loads(row.output)\n",
"\n",
" output[\"content\"] = output[\"function_call\"][\"arguments\"]\n",
" del output[\"function_call\"]\n",
"\n",
" sample = {\"messages\": input.copy() + [output]}\n",
" training_data.append(sample)\n",
"\n",
"# save the training data to data/train-gpt3.5.jsonl\n",
"\n",
"with open(\"data/train-gpt3.5.jsonl\", \"w\") as f:\n",
" for sample in training_data:\n",
" f.write(json.dumps(sample) + \"\\n\")\n",
"\n",
"print(f\"Sample training data:\")\n",
"pprint(training_data[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<File file id=file-faAdQ1KPxZH79ThW4Dbu4z1y at 0x7fa55db5c6d0> JSON: {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"uploaded\",\n",
" \"status_details\": null\n",
"}"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import openai\n",
"\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"openai.File.create(\n",
" file=open(\"data/train-gpt3.5.jsonl\", \"rb\"),\n",
" purpose=\"fine-tune\",\n",
" user_provided_filename=\"recipe-classification\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject list at 0x7fa55dbf6930> JSON: {\n",
" \"object\": \"list\",\n",
" \"data\": [\n",
" {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"processed\",\n",
" \"status_details\": null\n",
" }\n",
" ]\n",
"}"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.File.list()\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<FineTuningJob fine_tuning.job id=ftjob-EjjLxmj9P8apwPRk5s2NPSeB at 0x7fa55ddc4360> JSON: {\n",
" \"object\": \"fine_tuning.job\",\n",
" \"id\": \"ftjob-EjjLxmj9P8apwPRk5s2NPSeB\",\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"created_at\": 1693001190,\n",
" \"finished_at\": null,\n",
" \"fine_tuned_model\": null,\n",
" \"organization_id\": \"org-jRz4nVPMoeGHWL5nVR3Mb0kp\",\n",
" \"result_files\": [],\n",
" \"status\": \"created\",\n",
" \"validation_file\": null,\n",
" \"training_file\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"hyperparameters\": {\n",
" \"n_epochs\": 3\n",
" },\n",
" \"trained_tokens\": null\n",
"}"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.FineTuningJob.create(\n",
" training_file=\"file-faAdQ1KPxZH79ThW4Dbu4z1y\", model=\"gpt-3.5-turbo\"\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -6,24 +6,61 @@
"source": [
"In this notebook I'm using the OpenPipe client to capture a set of calls to the OpenAI API.\n",
"\n",
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁\n"
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁"
]
},
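The capture pattern the rest of the notebook relies on: import the OpenAI SDK through OpenPipe's wrapper so every request/response pair is logged for later export. Roughly like this -- the `openpipe` tagging kwarg is recalled from the 3.x client docs, so treat the exact syntax as an assumption:

```python
from openpipe import openai  # drop-in replacement that logs each call to OpenPipe

completion = openai.ChatCompletion.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Say hello"}],
    openpipe={"tags": {"prompt_id": "classify-recipe"}},  # assumed tag syntax
)
```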
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (2.8.2)\n",
"Requirement already satisfied: toml<0.11.0,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.10.2)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (2022.12.7)\n",
"Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.17.3)\n",
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.4)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.3.0)\n",
"Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.28.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.66.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (3.8.5)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil<3.0.0,>=2.8.2->openpipe==3.0.3) (1.16.0)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.14.0)\n",
"Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.7.1)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.1.1)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.26.13)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.3.1)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 datasets==2.14.4"
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2 datasets==2.14.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe.\n"
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe."
]
},
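The inspection cell is elided, but the shape of it is a `datasets` load plus a print (the dataset id below is a placeholder, not necessarily the one this notebook pulls):

```python
from datasets import load_dataset

# Placeholder dataset id -- substitute the recipes dataset the notebook actually uses.
recipes = load_dataset("corbt/all-recipes", split="train")
print(recipes[0])  # eyeball one sample before spending money labeling 5,000 of them
```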
{
@@ -95,16 +132,15 @@
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
"\n",
"In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
"\n",
"- has_non_fish_meat\n",
"- requires_oven\n",
"- requires_stove\n",
"- cook_time_over_30_mins\n",
"- main_dish\n",
" - has_non_fish_meat\n",
" - requires_oven\n",
" - requires_stove\n",
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: I'm looking for meals that my pescatarian brother/co-founder can make in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example).\n"
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
]
},
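Concretely, the key lives in a `.env` file that gets loaded at the top of the notebook. A minimal sketch (the variable name is assumed from the `.env.example` convention):

```python
import os
import dotenv

dotenv.load_dotenv()  # reads key=value pairs from ./.env into the environment

# Fail fast if the key didn't load -- name assumed, check .env.example.
assert os.getenv("OPENPIPE_API_KEY"), "Add OPENPIPE_API_KEY to your .env file"
```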
{
@@ -204,7 +240,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!\n"
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!"
]
},
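Because 5,000 sequential GPT-4 calls take a while and occasionally fail, it's worth caching each result on disk so the loop can resume where it left off -- the same `joblib.Memory` pattern the evaluation notebook uses. A sketch, with `call_gpt4_classifier` as a hypothetical stand-in for the function-call request defined earlier:

```python
import joblib
from tqdm import tqdm

memory = joblib.Memory("./cache", verbose=0)

@memory.cache  # results persist on disk, so re-running skips finished recipes
def classify_recipe(recipe: str) -> dict:
    return call_gpt4_classifier(recipe)  # hypothetical name for the GPT-4 call above

labels = [classify_recipe(r) for r in tqdm(recipes)]
```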
{
@@ -284,11 +320,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, now that my recipes are classified I'll download the training data.\n",
"Ok, now that my recipes are classified I'll download the training data. \n",
"\n",
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details! Just go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
"\n",
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`.\n"
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`."
]
}
],

View File

@@ -4,16 +4,142 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's get to the fun part -- training a model. I'll start by installing the dependencies.\n"
"Now let's get to the fun part -- training a model. I'll start by installing the dependencies."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: peft==0.5.0 in /usr/local/lib/python3.10/dist-packages (0.5.0)\n",
"\u001b[31mERROR: Could not find a version that satisfies the requirement python-dotenv==2.0.0 (from versions: 0.1.0, 0.1.2, 0.1.3, 0.1.5, 0.2.0, 0.3.0, 0.4.0, 0.5.0, 0.5.1, 0.6.0, 0.6.1, 0.6.2, 0.6.3, 0.6.4, 0.6.5, 0.7.0, 0.7.1, 0.8.0, 0.8.1, 0.8.2, 0.9.0, 0.9.1, 0.10.0, 0.10.1, 0.10.2, 0.10.3, 0.10.4, 0.10.5, 0.11.0, 0.12.0, 0.13.0, 0.14.0, 0.15.0, 0.16.0, 0.17.0, 0.17.1, 0.18.0, 0.19.0, 0.19.1, 0.19.2, 0.20.0, 0.21.0, 0.21.1, 1.0.0)\u001b[0m\u001b[31m\n",
"\u001b[0m\u001b[31mERROR: No matching distribution found for python-dotenv==2.0.0\u001b[0m\u001b[31m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"fatal: destination path 'axolotl' already exists and is not an empty directory.\n",
"Obtaining file:///workspace/OpenPipe/examples/classify-recipes/axolotl\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n",
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-ckp96ans/transformers_783779e09ad546a5be81c173eca5fd38\n",
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-ckp96ans/transformers_783779e09ad546a5be81c173eca5fd38\n",
" Resolved https://github.com/huggingface/transformers.git to commit f26099e7b5cf579f99a42bab6ddd371bf2c8d548\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hCollecting accelerate@ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b (from axolotl==0.1)\n",
" Using cached accelerate-0.22.0.dev0-py3-none-any.whl\n",
"Requirement already satisfied: bitsandbytes>=0.41.1 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.41.1)\n",
"Requirement already satisfied: addict in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.4.0)\n",
"Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.5.0)\n",
"Requirement already satisfied: PyYAML==6.0 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (6.0)\n",
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.14.4)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.99)\n",
"Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.15.8)\n",
"Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.6.1)\n",
"Requirement already satisfied: xformers in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.0.21)\n",
"Requirement already satisfied: optimum in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.11.2)\n",
"Requirement already satisfied: hf_transfer in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.3)\n",
"Requirement already satisfied: colorama in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.4.6)\n",
"Requirement already satisfied: numba in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.57.1)\n",
"Requirement already satisfied: numpy==1.24.4 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.24.4)\n",
"Requirement already satisfied: bert-score==0.3.13 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.3.13)\n",
"Requirement already satisfied: evaluate==0.4.0 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.4.0)\n",
"Requirement already satisfied: rouge-score==0.1.2 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.2)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.11.2)\n",
"Requirement already satisfied: scikit-learn==1.2.2 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.2.2)\n",
"Requirement already satisfied: pynvml in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (11.5.0)\n",
"Requirement already satisfied: flash-attn==2.0.8 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.0.8)\n",
"Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.0.1+cu118)\n",
"Requirement already satisfied: pandas>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.0.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.31.1 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (4.66.1)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (3.7.2)\n",
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (23.1)\n",
"Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.3.7)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (2023.6.0)\n",
"Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.16.4)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.18.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from flash-attn==2.0.8->axolotl==0.1) (1.11.1)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (1.4.0)\n",
"Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (3.8.1)\n",
"Requirement already satisfied: six>=1.14.0 in /usr/lib/python3/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (1.16.0)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.2.2->axolotl==0.1) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.2.2->axolotl==0.1) (3.2.0)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->axolotl==0.1) (12.0.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->axolotl==0.1) (3.8.5)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (3.9.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (0.3.2)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate@ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b->axolotl==0.1) (5.9.5)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from fire->axolotl==0.1) (2.3.0)\n",
"Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba->axolotl==0.1) (0.40.1)\n",
"Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from optimum->axolotl==0.1) (15.0.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum->axolotl==0.1) (1.11.1)\n",
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (8.1.7)\n",
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (3.1.32)\n",
"Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.29.2)\n",
"Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (0.4.0)\n",
"Requirement already satisfied: pathtools in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (0.1.2)\n",
"Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.3.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (68.0.0)\n",
"Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.4.4)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (4.24.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (4.7.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (15.0.7)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.3.1)\n",
"Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (4.0.10)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2023.3)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (2022.12.7)\n",
"Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->optimum->axolotl==0.1) (10.0)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (4.42.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (1.4.4)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (9.3.0)\n",
"Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (2.4.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum->axolotl==0.1) (1.2.1)\n",
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n",
"Installing collected packages: axolotl\n",
" Attempting uninstall: axolotl\n",
" Found existing installation: axolotl 0.1\n",
" Uninstalling axolotl-0.1:\n",
" Successfully uninstalled axolotl-0.1\n",
" Running setup.py develop for axolotl\n",
"Successfully installed axolotl\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install peft==0.5.0 python-dotenv==2.0.0\n",
"\n",
"!git clone https://github.com/OpenAccess-AI-Collective/axolotl\n",
@@ -24,7 +150,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note to the reader: since we'll be basing our fine-tuned model on Meta's Llama 2, you need to apply for access to the weights (which will be automatically granted). Follow the steps on [HuggingFace](https://huggingface.co/meta-llama/Llama-2-7b-hf), then create a read-only access token [here](https://huggingface.co/settings/tokens) and copy it into your .env file.\n"
"Note to the reader: since we'll be basing our fine-tuned model on Meta's Llama 2, you need to apply for access to the weights (which will be automatically granted). Follow the steps on [HuggingFace](https://huggingface.co/meta-llama/Llama-2-7b-hf), then create a read-only access token [here](https://huggingface.co/settings/tokens) and copy it into your .env file."
]
},
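With the token in `.env`, authenticating against the Hub before training is one call (the environment variable name here is an assumption):

```python
import os
import dotenv
from huggingface_hub import login

dotenv.load_dotenv()
# Authenticate so the gated meta-llama/Llama-2-7b-hf weights can be downloaded.
login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])
```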
{
@@ -59,7 +185,7 @@
"\n",
"In this case I'm using 8-bit training to use less GPU RAM, and sample packing to maximize GPU utilization. You can read more about the available options at https://github.com/OpenAccess-AI-Collective/axolotl.\n",
"\n",
"The training run options are defined in [training-config.yaml](./training-config.yaml).\n"
"The training run options are defined in [training-config.yaml](./training-config.yaml)."
]
},
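If you want to sanity-check the options before kicking off a run, the config is plain YAML and easy to peek at (the keys printed below are common Axolotl settings matching the features mentioned above; the actual file may use more):

```python
import yaml

# Inspect the Axolotl run configuration referenced above.
with open("training-config.yaml") as f:
    config = yaml.safe_load(f)

for key in ("base_model", "load_in_8bit", "sample_packing"):
    print(key, "=", config.get(key))
```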
{
@@ -865,7 +991,7 @@
"source": [
"Sweet! I now have a new directory `./models/run1`. This contains my trained model, which I can use to classify more recipes.\n",
"\n",
"There's one more step though. I trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models directly yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-shaped model. I've defined a small helper to do that called `merge_lora_model` that I'll use below.\n"
"There's one more step though. I trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models directly yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-shaped model. I've defined a small helper to do that called `merge_lora_model` that I'll use below."
]
},
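`merge_lora_model` is the author's own helper, but the operation underneath is the standard peft merge: load the base model, attach the adapter, fold the low-rank deltas into the base weights, and save the result. A generic sketch of that pattern (not the author's actual helper; adapter and output paths assumed from the directory names above):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base Llama 2 weights and apply the trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base, "./models/run1")

# Fold the adapter into the base weights and save a standard Llama-shaped model.
merged = model.merge_and_unload()
merged.save_pretrained("./models/run1/merged")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf").save_pretrained(
    "./models/run1/merged"
)
```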
{
@@ -918,7 +1044,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check.\n"
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check."
]
}
],