diff --git a/examples/classify-recipes/.env.example b/examples/classify-recipes/.env.example
new file mode 100644
index 0000000..819c9fc
--- /dev/null
+++ b/examples/classify-recipes/.env.example
@@ -0,0 +1,4 @@
+OPENAI_API_KEY="[your OpenAI API key]"
+OPENPIPE_API_KEY="[your OpenPipe API key from https://app.openpipe.ai/project/settings]"
+
+WANDB_API_KEY="[Optionally, you can set a Weights & Biases API key to track your training run. Create it at https://wandb.ai/settings]"
\ No newline at end of file
diff --git a/examples/classify-recipes/__init__.py b/examples/classify-recipes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/classify-recipes/evaluate.ipynb b/examples/classify-recipes/evaluate.ipynb
new file mode 100644
index 0000000..68e451f
--- /dev/null
+++ b/examples/classify-recipes/evaluate.ipynb
@@ -0,0 +1,167 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
+ "Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
+ "Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
+ "Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
+ "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
+ "Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
+ "Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
+ "Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
+ "Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
+ "Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
+ "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
+ "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
+ "Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
+ "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
+ "Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
+ "Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
+ "Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
+ "Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
+ "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
+ "Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
+ "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install vllm==0.1.3 pandas==2.0.3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? Since that is data formatted the same way as our training data but that we didn't use for training, we can use it to check our model's performance."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "test_data = pd.read_json(\"./data/test.jsonl\", lines=True)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'axolotl.prompters'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39maxolotl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompters\u001b[39;00m \u001b[39mimport\u001b[39;00m UnpromptedPrompter\n\u001b[1;32m 2\u001b[0m prompter \u001b[39m=\u001b[39m UnpromptedPrompter()\n\u001b[1;32m 4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_prompt\u001b[39m(\u001b[39minput\u001b[39m: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'axolotl.prompters'"
+ ]
+ }
+ ],
+ "source": [
+ "from axolotl.prompters import UnpromptedPrompter\n",
+ "\n",
+ "prompter = UnpromptedPrompter()\n",
+ "\n",
+ "\n",
+ "def format_prompt(input: str) -> str:\n",
+ " return next(prompter.build_prompt(input))\n",
+ "\n",
+ "\n",
+ "prompts = test_data[\"input\"].apply(format_prompt)\n",
+ "\n",
+ "print(f\"Sample prompt:\\n-----------\\n{prompts[0]}\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/classify-recipes/generate-data.ipynb b/examples/classify-recipes/generate-data.ipynb
new file mode 100644
index 0000000..23b1ce9
--- /dev/null
+++ b/examples/classify-recipes/generate-data.ipynb
@@ -0,0 +1,329 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this notebook I'm using the OpenPipe client to capture a set of calls to the OpenAI API.\n",
+ "\n",
+ "For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
+ "Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
+ "Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
+ "Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
+ "Requirement already satisfied: python-dateutil<3.0.0,>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (2.8.2)\n",
+ "Requirement already satisfied: toml<0.11.0,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.10.2)\n",
+ "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (2022.12.7)\n",
+ "Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.17.3)\n",
+ "Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.4)\n",
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.3.0)\n",
+ "Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.28.1)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.66.1)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (3.8.5)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil<3.0.0,>=2.8.2->openpipe==3.0.3) (1.16.0)\n",
+ "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.14.0)\n",
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.7.1)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.1.1)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.26.13)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.3.1)\n",
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.1.2)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Recipe dataset:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['recipe'],\n",
+ " num_rows: 5000\n",
+ "})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First recipe:\n",
+ " Shrimp Creole\n",
+ "\n",
+ "Ingredients:\n",
+ "- 20 shrimp (8 oz.)\n",
+ "- 2 c. (16 oz. can) tomato sauce\n",
+ "- 1 small onion, chopped\n",
+ "- 1 celery stalk, chopped\n",
+ "- 1/4 green bell pepper, diced\n",
+ "- 1/4 c. sliced mushrooms\n",
+ "- 3 Tbsp. parsley\n",
+ "- 1/2 tsp. pepper\n",
+ "- 1 to 1-1/2 c. brown rice, prepared according to pkg. directions (not included in exchanges)\n",
+ "\n",
+ "Directions:\n",
+ "- Peel, devein and wash shrimp; set aside.\n",
+ "- (If shrimp are frozen, let thaw first in the refrigerator.)\n",
+ "- Simmer tomato sauce, onion, celery, green pepper, mushrooms, parsley and pepper in skillet for 30 minutes.\n",
+ "- Add shrimp and cook 10 to 15 minutes more, until shrimp are tender.\n",
+ "- Serve over brown rice.\n",
+ "- Serves 2.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "recipes = load_dataset(\"corbt/unlabeled-recipes\")[\"train\"]\n",
+ "print(\"Recipe dataset shape:\\n------------------\")\n",
+ "display(recipes)\n",
+ "print(\"First recipe:\\n------------------\", recipes[\"recipe\"][0])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
+ "\n",
+ "We'll use [OpenPipe](https://github.com/openpipe/openpipe) to track our calls and form a training dataset. Create an account and a project, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'has_non_fish_meat': False,\n",
+ " 'requires_oven': True,\n",
+ " 'requires_stove': True,\n",
+ " 'cook_time_over_30_mins': False,\n",
+ " 'main_dish': False}"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from openpipe import openai, configure_openpipe\n",
+ "import json\n",
+ "import os\n",
+ "import dotenv\n",
+ "\n",
+ "dotenv.load_dotenv()\n",
+ "\n",
+ "configure_openpipe(api_key=os.environ[\"OPENPIPE_API_KEY\"])\n",
+ "\n",
+ "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
+ "\n",
+ "\n",
+ "def classify_recipe(recipe: str):\n",
+ " completion = openai.ChatCompletion.create(\n",
+ " model=\"gpt-4\",\n",
+ " messages=[\n",
+ " {\n",
+ " \"role\": \"system\",\n",
 "          \"content\": \"Your goal is to classify a recipe along several dimensions. Pay attention to the instructions.\",\n",
+ " },\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": recipe,\n",
+ " },\n",
+ " ],\n",
+ " functions=[\n",
+ " {\n",
+ " \"name\": \"classify\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"has_non_fish_meat\": {\n",
+ " \"type\": \"boolean\",\n",
+ " \"description\": \"True if the recipe contains any meat or meat products (eg. chicken broth) besides fish\",\n",
+ " },\n",
+ " \"requires_oven\": {\n",
+ " \"type\": \"boolean\",\n",
+ " \"description\": \"True if the recipe requires an oven\",\n",
+ " },\n",
+ " \"requires_stove\": {\n",
+ " \"type\": \"boolean\",\n",
+ " \"description\": \"True if the recipe requires a stove\",\n",
+ " },\n",
+ " \"cook_time_over_30_mins\": {\n",
+ " \"type\": \"boolean\",\n",
+ " \"description\": \"True if the recipe takes over 30 minutes to prepare and cook, including waiting time\",\n",
+ " },\n",
+ " \"main_dish\": {\n",
+ " \"type\": \"boolean\",\n",
+ " \"description\": \"True if the recipe can be served as a main dish\",\n",
+ " },\n",
+ " },\n",
+ " \"required\": [\n",
+ " \"has_non_fish_meat\",\n",
+ " \"requires_oven\",\n",
+ " \"requires_stove\",\n",
+ " \"cook_time_over_30_mins\",\n",
 "                    \"main_dish\",\n",
+ " ],\n",
+ " },\n",
+ " }\n",
+ " ],\n",
+ " function_call={\n",
+ " \"name\": \"classify\",\n",
+ " },\n",
+ " openpipe={\"tags\": {\"prompt_id\": \"classify-recipe\"}, \"cache\": True},\n",
+ " )\n",
+ " return json.loads(completion.choices[0].message.function_call.arguments)\n",
+ "\n",
+ "\n",
+ "classify_recipe(recipes[\"recipe\"][-1])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Classifying recipe 0/5000: Shrimp Creole\n",
+ "Classifying recipe 100/5000: Spoon Bread\n",
+ "Classifying recipe 200/5000: Quadrangle Grille'S Pumpkin-Walnut Cheesecake\n",
+ "Classifying recipe 300/5000: Broccoli Casserole\n",
+ "Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
+ "520 is not a valid HTTPStatus\n",
+ "Classifying recipe 400/5000: Paal Payasam (3-Ingredient Rice Pudding)\n",
+ "Classifying recipe 500/5000: Dirt Dessert\n",
+ "Classifying recipe 600/5000: Dolma, Stuffed Dried Peppers And Eggplants\n",
+ "Classifying recipe 700/5000: Party Pecan Pies\n",
+ "Classifying recipe 800/5000: Pie Crust\n",
+ "Classifying recipe 900/5000: Russian Dressing(Salad Dressing) \n",
+ "Classifying recipe 1000/5000: O'Brien Potatoes\n",
+ "Classifying recipe 1100/5000: Monster Cookies\n",
+ "Classifying recipe 1200/5000: Striped Fruit Pops\n",
+ "Classifying recipe 1300/5000: Cute Heart-Shaped Fried Egg\n",
+ "Classifying recipe 1400/5000: Steak Marinade\n",
+ "Classifying recipe 1500/5000: Bbq Sauce For Fish Recipe\n",
+ "Classifying recipe 1600/5000: Barbecue Ranch Salad\n",
+ "Classifying recipe 1700/5000: White Fudge\n",
+ "Classifying recipe 1800/5000: Seaton Chocolate Chip Cookies\n",
+ "Classifying recipe 1900/5000: Beef Stroganoff\n",
+ "Classifying recipe 2000/5000: Lemon Delight\n",
+ "Classifying recipe 2100/5000: Cream Cheese Chicken Chili\n",
+ "Classifying recipe 2200/5000: Bean Salad\n",
+ "Classifying recipe 2300/5000: Green Beans Almondine\n",
+ "Classifying recipe 2400/5000: Radish-And-Avocado Salad\n",
+ "Classifying recipe 2500/5000: Salsa Rojo\n",
+ "Classifying recipe 2600/5000: Pepperoni Bread\n",
+ "Classifying recipe 2700/5000: Sabzi Polow\n",
+ "Classifying recipe 2800/5000: Italian Vegetable Pizzas\n",
+ "Error classifying recipe 2801: Bad gateway. {\"error\":{\"code\":502,\"message\":\"Bad gateway.\",\"param\":null,\"type\":\"cf_bad_gateway\"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Thu, 24 Aug 2023 15:44:45 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7fbca943df684de1-MCI', 'alt-svc': 'h3=\":443\"; ma=86400'}\n",
+ "Classifying recipe 2900/5000: Hot Fudge Sauce, Soda Shop Style\n",
+ "Classifying recipe 3000/5000: Meatball Soup With Vegetables And Brown Rice\n",
+ "Classifying recipe 3100/5000: Herbed Potatoes And Onions\n",
+ "Classifying recipe 3200/5000: Apple Crunch Pie (2 Extra Servings)\n",
+ "Classifying recipe 3300/5000: Pineapple-Orange Punch\n",
+ "Classifying recipe 3400/5000: Turkey Veggie Burgers With Avocado Mayo\n",
+ "Error reporting to OpenPipe: 520 is not a valid HTTPStatus\n",
+ "520 is not a valid HTTPStatus\n",
+ "Classifying recipe 3500/5000: Pear & Goat Cheese Salad\n",
+ "Classifying recipe 3600/5000: Triple Chocolate Cookies\n",
+ "Classifying recipe 3700/5000: Strawberry Banana Yogurt Pops\n",
+ "Error classifying recipe 3779: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600)\n",
+ "Classifying recipe 3800/5000: Chicken Croquettes\n",
+ "Classifying recipe 3900/5000: Mushroom Casserole\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i, recipe in enumerate(recipes[\"recipe\"]):\n",
+ " if i % 100 == 0:\n",
+ " recipe_title = recipe.split(\"\\n\")[0]\n",
+ " print(f\"Classifying recipe {i}/{len(recipes)}: {recipe_title}\")\n",
+ " try:\n",
+ " classify_recipe(recipe)\n",
+ " except Exception as e:\n",
+ " print(f\"Error classifying recipe {i}: {e}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
 "# Ok, we have our training data captured by OpenPipe! Continue in train.ipynb to fine-tune a model on it."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/classify-recipes/train.ipynb b/examples/classify-recipes/train.ipynb
index e2db259..3da98ab 100644
--- a/examples/classify-recipes/train.ipynb
+++ b/examples/classify-recipes/train.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Now let's get to the fun part -- training a model. We'll start by installing our dependencies."
+ "Now let's get to the fun part -- training a model. I'll start by installing the dependencies."
]
},
{
@@ -177,11 +177,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "We'll use the [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage our training run. It includes a lot of neat tricks that speed up training without sacrificing quality.\n",
+ "I'll use the [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage our training run. It includes a lot of neat tricks that speed up training without sacrificing quality.\n",
"\n",
- "In this case we'll use 8-bit training to use less GPU RAM, and sample packing to maximize GPU utilization. You can read more about the available options at https://github.com/OpenAccess-AI-Collective/axolotl.\n",
+ "In this case I'm using 8-bit training to use less GPU RAM, and sample packing to maximize GPU utilization. You can read more about the available options at https://github.com/OpenAccess-AI-Collective/axolotl.\n",
"\n",
- "The training run options we're using here are defined in [training-args.yaml](./training-args.yaml)."
+ "The training run options are defined in [training-config.yaml](./training-config.yaml)."
]
},
{
@@ -365,16 +365,16 @@
}
],
"source": [
- "!accelerate launch ./axolotl/scripts/finetune.py training-args.yaml"
+ "!accelerate launch ./axolotl/scripts/finetune.py training-config.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Nice work! If you look on your filesystem you should see a new directory `./models/recipe-model`. This contains your trained model, which you can use to classify more recipes.\n",
+ "Sweet! If you look on your filesystem you should see a new directory `./models/run1`. This contains your trained model, which you can use to classify more recipes.\n",
"\n",
- "Before we using it though, we need to *merge* the model. We trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-style model. We've defined a helper to do that that we'll use below."
+ "There's one more step though. I trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models directly yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-shaped model. I've defined a small helper to do that called `merge_lora_model` that I'll use below."
]
},
{
@@ -418,7 +418,7 @@
"from utils import merge_lora_model\n",
"\n",
"print(\"Merging model (this could take a while)\")\n",
- "final_model_dir = merge_lora_model(\"training-args.yaml\")\n",
+ "final_model_dir = merge_lora_model(\"training-config.yaml\")\n",
"print(f\"Final model saved to '{final_model_dir}'\")\n"
]
}
diff --git a/examples/classify-recipes/training-config.yaml b/examples/classify-recipes/training-config.yaml
new file mode 100644
index 0000000..ab96e76
--- /dev/null
+++ b/examples/classify-recipes/training-config.yaml
@@ -0,0 +1,73 @@
+# This file is used by the training script in train.ipynb. You can read more about
+# the format and see more examples at https://github.com/OpenAccess-AI-Collective/axolotl.
+# One of the parameters you might want to play around with is `num_epochs`: if you have a
+# smaller dataset size, making that large can have good results.
+
+base_model: meta-llama/Llama-2-7b-hf
+base_model_config: meta-llama/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: ./data/train.jsonl
+ type: alpaca_instruct.load_no_prompt
+dataset_prepared_path: ./data/last_run_prepared
+val_set_size: 0.05
+output_dir: ./models/run1
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+# This will report stats from your training run to https://wandb.ai/. If you don't want to create a wandb account you can comment this section out.
+wandb_project: classify-recipes
+wandb_entity:
+wandb_watch:
+wandb_run_id: run1
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 5
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+save_steps: 60
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
\ No newline at end of file
diff --git a/examples/classify-recipes/utils.py b/examples/classify-recipes/utils.py
new file mode 100644
index 0000000..35eada7
--- /dev/null
+++ b/examples/classify-recipes/utils.py
@@ -0,0 +1,58 @@
+import yaml
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from peft import PeftModel
+import os
+
+
+def merge_lora_model(config_file: str) -> str:
+    """Merge a trained LoRA adapter into its base model and save the result.
+
+    Reads the axolotl training config to find the base model name and the
+    LoRA output directory, merges the adapter weights into the base model,
+    and writes the merged model plus its tokenizer to `<output_dir>/merged`.
+
+    Args:
+        config_file: Path to the axolotl training config YAML
+            (e.g. "training-config.yaml").
+
+    Returns:
+        Path to the directory containing the merged model.
+    """
+    # safe_load is sufficient (and safer): the config is plain data.
+    with open(config_file, "r") as f:
+        config = yaml.safe_load(f)
+
+    base_model = config["base_model"]
+    lora_model = config["output_dir"]
+    merged_model = f"{lora_model}/merged"
+
+    # Merging is slow and deterministic, so skip it if a previous run
+    # already produced the merged model.
+    if os.path.exists(merged_model):
+        print(f"Model {merged_model} already exists, skipping")
+        return merged_model
+
+    print("Loading base model")
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        return_dict=True,
+        torch_dtype=torch.float16,
+    )
+
+    print("Loading PEFT model")
+    model = PeftModel.from_pretrained(model, lora_model)
+    print("Running merge_and_unload")
+    model = model.merge_and_unload()
+
+    tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+    model.save_pretrained(merged_model)
+    tokenizer.save_pretrained(merged_model)
+    print(f"Model saved to {merged_model}")
+
+    return merged_model
+
+
+# Backward-compatible alias for the original function name.
+merge = merge_lora_model