Compare commits

...

12 Commits

Author SHA1 Message Date
arcticfly
5df898a4f6 Add support for CommonJS in sdk (#222)
* trying rollup

* Imports mostly working

* Remove yalc.lock

* Use tsup to build

* Disable code splitting

* Update cjs exports

* Use openpipe from npm registry in app

* Remove npmignore, fix lint

* Copy the README to dist

* Update openpipe version

* Remove second openai entrypoint

* Update openpipe version

---------

Co-authored-by: Kyle Corbitt <kyle@corbt.com>
2023-09-12 14:30:37 -07:00
Kyle Corbitt
a33f674ccd more README changes 2023-09-11 20:27:35 -07:00
Kyle Corbitt
0062952eb2 readme links 2023-09-12 00:49:09 +00:00
Kyle Corbitt
381604bc88 Merge pull request #221 from OpenPipe/examples
Examples
2023-09-11 17:48:05 -07:00
Kyle Corbitt
db69b8e496 clean up example 2023-09-12 00:47:22 +00:00
David Corbitt
88be0b07a9 Use alternate credential for blob service client 2023-09-07 13:07:41 -07:00
David Corbitt
ff621f2191 update llama versions 2023-09-07 12:05:18 -07:00
arcticfly
1e98972b6a Add docs for importing data (#217) 2023-09-07 11:10:00 -07:00
arcticfly
c5bca87486 Reword downloading status text (#216)
* Reword downloading status text

* Disable closing modal while sending in progress

* Auto-refresh on upload completion

* Make error red

* Remove red from error message
2023-09-07 10:49:25 -07:00
David Corbitt
fc1f15fee7 Remove experiment button 2023-09-07 09:35:49 -07:00
arcticfly
606a524c11 Allow bulk importing (#215)
* Remove console

* Add upload data modal

* Add dataset entry deletion and export

* Remove consoles

* Upload training data through Azure Blob Storage

* Give progress updates on downloads

* Delete rows in chunks

* Fix lint

* Add FileUploadsCard in bottom right corner
2023-09-07 09:19:47 -07:00
Kyle Corbitt
38e28fa30a benchmark comparison to gpt-3.5 and gpt-3.5 finetuned 2023-08-28 03:55:50 +00:00
52 changed files with 2919 additions and 1003 deletions

View File

@@ -40,3 +40,8 @@ SMTP_HOST="placeholder"
SMTP_PORT="placeholder"
SMTP_LOGIN="placeholder"
SMTP_PASSWORD="placeholder"
# Azure credentials are necessary for uploading large training data files
AZURE_STORAGE_ACCOUNT_NAME="placeholder"
AZURE_STORAGE_ACCOUNT_KEY="placeholder"
AZURE_STORAGE_CONTAINER_NAME="placeholder"

app/.gitignore vendored
View File

@@ -47,3 +47,7 @@ yarn-error.log*
# custom openai initialization
src/server/utils/openaiCustomConfig.json
# yalc
.yalc
yalc.lock

View File

@@ -26,6 +26,8 @@
"dependencies": {
"@anthropic-ai/sdk": "^0.5.8",
"@apidevtools/json-schema-ref-parser": "^10.1.0",
"@azure/identity": "^3.3.0",
"@azure/storage-blob": "12.15.0",
"@babel/standalone": "^7.22.9",
"@chakra-ui/anatomy": "^2.2.0",
"@chakra-ui/next-js": "^2.1.4",
@@ -69,6 +71,7 @@
"jsonschema": "^1.4.1",
"kysely": "^0.26.1",
"kysely-codegen": "^0.10.1",
"llama-tokenizer-js": "^1.1.3",
"lodash-es": "^4.17.21",
"lucide-react": "^0.265.0",
"marked": "^7.0.3",
@@ -79,7 +82,7 @@
"nextjs-routes": "^2.0.1",
"nodemailer": "^6.9.4",
"openai": "4.0.0-beta.7",
"openpipe": "^0.3.0",
"openpipe": "0.4.0-beta.1",
"openpipe-dev": "workspace:^",
"pg": "^8.11.2",
"pluralize": "^8.0.0",

View File

@@ -0,0 +1,5 @@
-- AlterTable
ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL,
ALTER COLUMN "inputTokens" DROP DEFAULT,
ALTER COLUMN "outputTokens" DROP DEFAULT,
ALTER COLUMN "type" DROP DEFAULT;

View File

@@ -0,0 +1,23 @@
-- CreateEnum
CREATE TYPE "DatasetFileUploadStatus" AS ENUM ('PENDING', 'DOWNLOADING', 'PROCESSING', 'SAVING', 'COMPLETE', 'ERROR');
-- CreateTable
CREATE TABLE "DatasetFileUpload" (
"id" UUID NOT NULL,
"datasetId" UUID NOT NULL,
"blobName" TEXT NOT NULL,
"fileName" TEXT NOT NULL,
"fileSize" INTEGER NOT NULL,
"progress" INTEGER NOT NULL DEFAULT 0,
"status" "DatasetFileUploadStatus" NOT NULL DEFAULT 'PENDING',
"uploadedAt" TIMESTAMP(3) NOT NULL,
"visible" BOOLEAN NOT NULL DEFAULT true,
"errorMessage" TEXT,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "DatasetFileUpload_pkey" PRIMARY KEY ("id")
);
-- AddForeignKey
ALTER TABLE "DatasetFileUpload" ADD CONSTRAINT "DatasetFileUpload_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -176,12 +176,41 @@ model OutputEvaluation {
@@unique([modelResponseId, evaluationId])
}
enum DatasetFileUploadStatus {
PENDING
DOWNLOADING
PROCESSING
SAVING
COMPLETE
ERROR
}
model DatasetFileUpload {
id String @id @default(uuid()) @db.Uuid
datasetId String @db.Uuid
dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
blobName String
fileName String
fileSize Int
progress Int @default(0) // Percentage
status DatasetFileUploadStatus @default(PENDING)
uploadedAt DateTime
visible Boolean @default(true)
errorMessage String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
}
model Dataset {
id String @id @default(uuid()) @db.Uuid
name String
datasetEntries DatasetEntry[]
fineTunes FineTune[]
datasetFileUploads DatasetFileUpload[]
trainingRatio Float @default(0.8)
projectId String @db.Uuid
@@ -199,8 +228,8 @@ enum DatasetEntryType {
model DatasetEntry {
id String @id @default(uuid()) @db.Uuid
loggedCallId String @db.Uuid
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
loggedCallId String? @db.Uuid
loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
input Json @default("[]")
output Json?

View File

@@ -7,12 +7,14 @@ import { BetaModal } from "./BetaModal";
const ActionButton = ({
icon,
iconBoxSize = 3.5,
label,
requireBeta = false,
onClick,
...buttonProps
}: {
icon: IconType;
iconBoxSize?: number;
label: string;
requireBeta?: boolean;
onClick?: () => void;
@@ -39,7 +41,9 @@ const ActionButton = ({
{...buttonProps}
>
<HStack spacing={1}>
{icon && <Icon as={icon} color={requireBeta ? "orange.400" : undefined} />}
{icon && (
<Icon as={icon} boxSize={iconBoxSize} color={requireBeta ? "orange.400" : undefined} />
)}
<Text display={{ base: "none", md: "flex" }}>{label}</Text>
</HStack>
</Button>

View File

@@ -0,0 +1,107 @@
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { BsTrash } from "react-icons/bs";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import pluralize from "pluralize";
const DeleteButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Delete"
icon={BsTrash}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<DeleteDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default DeleteButton;
const DeleteDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const deleteRowsMutation = api.datasetEntries.delete.useMutation();
const utils = api.useContext();
const [deleteRows, deletionInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size) return;
// divide selectedIds into chunks of 15000 to reduce request size
const chunkSize = 15000;
const idsArray = Array.from(selectedIds);
for (let i = 0; i < idsArray.length; i += chunkSize) {
const response = await deleteRowsMutation.mutateAsync({
ids: idsArray.slice(i, i + chunkSize),
});
if (maybeReportError(response)) return;
}
await utils.datasetEntries.list.invalidate();
disclosure.onClose();
clearSelectedIds();
}, [deleteRowsMutation, dataset, selectedIds, utils]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={BsTrash} />
<Text>Delete Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
Are you sure you want to delete the <b>{selectedIds.size}</b>{" "}
{pluralize("row", selectedIds.size)} rows you've selected?
</Text>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="red" onClick={deleteRows} isLoading={deletionInProgress} minW={24}>
Delete
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,182 @@
import { useState, useEffect } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Checkbox,
NumberInput,
NumberInputField,
NumberInputStepper,
NumberIncrementStepper,
NumberDecrementStepper,
Collapse,
Flex,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { AiOutlineDownload } from "react-icons/ai";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
import InfoCircle from "../InfoCircle";
const ExportButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Download"
icon={AiOutlineDownload}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<ExportDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default ExportButton;
const ExportDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const [testingSplit, setTestingSplit] = useState(10);
const [removeDuplicates, setRemoveDuplicates] = useState(false);
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
useEffect(() => {
if (disclosure.isOpen) {
setTestingSplit(10);
setRemoveDuplicates(false);
}
}, [disclosure.isOpen]);
const exportDataMutation = api.datasetEntries.export.useMutation();
const [exportData, exportInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size || !testingSplit) return;
const response = await exportDataMutation.mutateAsync({
datasetId: dataset.id,
datasetEntryIds: Array.from(selectedIds),
testingSplit,
removeDuplicates,
});
const dataUrl = `data:application/zip;base64,${response}`; // response is a base64-encoded zip archive
const blob = await fetch(dataUrl).then((res) => res.blob());
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `data.zip`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
disclosure.onClose();
clearSelectedIds();
}, [exportDataMutation, dataset, selectedIds, testingSplit, removeDuplicates]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={AiOutlineDownload} />
<Text>Export Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
We'll export the <b>{selectedIds.size}</b> rows you have selected in the OpenAI
training format.
</Text>
<VStack alignItems="flex-start" spacing={4}>
<Flex
flexDir={{ base: "column", md: "row" }}
alignItems={{ base: "flex-start", md: "center" }}
>
<HStack w={48} alignItems="center" spacing={1}>
<Text fontWeight="bold">Testing Split:</Text>
<InfoCircle tooltipText="The percent of your logs that will be reserved for testing and saved in another file. Logs are split randomly." />
</HStack>
<HStack>
<NumberInput
defaultValue={10}
onChange={(_, num) => setTestingSplit(num)}
min={1}
max={100}
w={48}
>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
</HStack>
</Flex>
</VStack>
<VStack alignItems="flex-start" spacing={0}>
<Button
variant="unstyled"
color="blue.600"
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
>
<HStack>
<Text>Advanced Options</Text>
<Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
</HStack>
</Button>
<Collapse in={showAdvancedOptions} unmountOnExit={true}>
<VStack align="stretch" pt={4}>
<HStack>
<Checkbox
colorScheme="blue"
isChecked={removeDuplicates}
onChange={(e) => setRemoveDuplicates(e.target.checked)}
>
<Text>Remove duplicates</Text>
</Checkbox>
<InfoCircle tooltipText="To avoid overfitting and speed up training, automatically deduplicate logs with matching input and output." />
</HStack>
</VStack>
</Collapse>
</VStack>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="blue" onClick={exportData} isLoading={exportInProgress} minW={24}>
Download
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,139 @@
import { useState, useEffect } from "react";
import { VStack, HStack, Button, Text, Progress, IconButton, Portal } from "@chakra-ui/react";
import { BsX } from "react-icons/bs";
import { type RouterOutputs, api } from "~/utils/api";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { formatFileSize } from "~/utils/utils";
type FileUpload = RouterOutputs["datasets"]["listFileUploads"][0];
const FileUploadsCard = () => {
const dataset = useDataset();
const [fileUploadsRefetchInterval, setFileUploadsRefetchInterval] = useState<number>(500);
const fileUploads = api.datasets.listFileUploads.useQuery(
{ datasetId: dataset.data?.id as string },
{ enabled: !!dataset.data?.id, refetchInterval: fileUploadsRefetchInterval },
);
useEffect(() => {
if (fileUploads?.data?.some((fu) => fu.status !== "COMPLETE" && fu.status !== "ERROR")) {
setFileUploadsRefetchInterval(500);
} else {
setFileUploadsRefetchInterval(15000);
}
}, [fileUploads]);
const utils = api.useContext();
const hideFileUploadsMutation = api.datasets.hideFileUploads.useMutation();
const [hideAllFileUploads] = useHandledAsyncCallback(async () => {
if (!fileUploads.data?.length) return;
await hideFileUploadsMutation.mutateAsync({
fileUploadIds: fileUploads.data.map((upload) => upload.id),
});
await utils.datasets.listFileUploads.invalidate();
}, [hideFileUploadsMutation, fileUploads.data, utils]);
if (!fileUploads.data?.length) return null;
return (
<Portal>
<VStack
w={72}
borderRadius={8}
position="fixed"
bottom={8}
right={8}
overflow="hidden"
borderWidth={1}
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);"
minW={0}
bgColor="white"
>
<HStack p={4} w="full" bgColor="gray.200" justifyContent="space-between">
<Text fontWeight="bold">Uploads</Text>
<IconButton
aria-label="Close uploads"
as={BsX}
boxSize={6}
minW={0}
variant="ghost"
onClick={hideAllFileUploads}
cursor="pointer"
/>
</HStack>
{fileUploads?.data?.map((upload) => <FileUploadRow key={upload.id} fileUpload={upload} />)}
</VStack>
</Portal>
);
};
export default FileUploadsCard;
const FileUploadRow = ({ fileUpload }: { fileUpload: FileUpload }) => {
const { id, fileName, fileSize, progress, status, errorMessage } = fileUpload;
const utils = api.useContext();
const hideFileUploadsMutation = api.datasets.hideFileUploads.useMutation();
const [hideFileUpload, hidingInProgress] = useHandledAsyncCallback(async () => {
await hideFileUploadsMutation.mutateAsync({ fileUploadIds: [id] });
await utils.datasets.listFileUploads.invalidate();
}, [id, hideFileUploadsMutation, utils]);
useEffect(() => {
// Invalidate dataset entries list when upload is processed
if (status === "COMPLETE") void utils.datasetEntries.list.invalidate();
}, [status, utils]);
return (
<VStack w="full" alignItems="flex-start" p={4} borderBottomWidth={1}>
<HStack w="full" justifyContent="space-between" alignItems="flex-start">
<VStack alignItems="flex-start" spacing={0}>
<Text fontWeight="bold">{fileName}</Text>
<Text fontSize="xs">({formatFileSize(fileSize, 2)})</Text>
</VStack>
<Button
aria-label="Hide file upload"
minW={0}
variant="ghost"
isLoading={hidingInProgress}
onClick={hideFileUpload}
size="xs"
>
HIDE
</Button>
</HStack>
{errorMessage ? (
<Text alignSelf="center" pt={2}>
{errorMessage}
</Text>
) : (
<>
<Text alignSelf="center" fontSize="xs">
{getStatusText(status)}
</Text>
<Progress w="full" value={progress} borderRadius={2} />
</>
)}
</VStack>
);
};
const getStatusText = (status: FileUpload["status"]) => {
switch (status) {
case "PENDING":
return "Pending";
case "DOWNLOADING":
return "Loading Data";
case "PROCESSING":
return "Processing";
case "SAVING":
return "Saving";
case "COMPLETE":
return "Complete";
case "ERROR":
return "Error";
}
};

View File

@@ -22,10 +22,9 @@ import { useRouter } from "next/router";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import InputDropdown from "../InputDropdown";
import { FiChevronDown } from "react-icons/fi";
// import { FiChevronDown } from "react-icons/fi";
const SUPPORTED_BASE_MODELS = ["llama2-7b", "llama2-13b", "llama2-70b", "gpt-3.5-turbo"];
@@ -53,7 +52,6 @@ const FineTuneButton = () => {
export default FineTuneButton;
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const dataset = useDataset().data;
const datasetEntries = useDatasetEntries().data;
@@ -73,7 +71,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const createFineTuneMutation = api.fineTunes.create.useMutation();
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return;
if (!modelSlug || !selectedBaseModel || !dataset) return;
await createFineTuneMutation.mutateAsync({
slug: modelSlug,
baseModel: selectedBaseModel,
@@ -83,7 +81,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
await utils.fineTunes.list.invalidate();
await router.push({ pathname: "/fine-tunes" });
disclosure.onClose();
}, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]);
}, [createFineTuneMutation, modelSlug, selectedBaseModel]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
@@ -133,12 +131,12 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
/>
</HStack>
</VStack>
<Button variant="unstyled" color="blue.600">
{/* <Button variant="unstyled" color="blue.600">
<HStack>
<Text>Advanced Options</Text>
<Icon as={FiChevronDown} />
</HStack>
</Button>
</Button> */}
</VStack>
</ModalBody>
<ModalFooter>

View File

@@ -0,0 +1,288 @@
import { useState, useEffect, useRef, useCallback } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Box,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import pluralize from "pluralize";
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
import { uploadDatasetEntryFile } from "~/utils/azure/website";
import { formatFileSize } from "~/utils/utils";
const UploadDataButton = () => {
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Upload Data"
icon={AiOutlineCloudUpload}
iconBoxSize={4}
requireBeta
/>
<UploadDataModal disclosure={disclosure} />
</>
);
};
export default UploadDataButton;
const UploadDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const [validationError, setValidationError] = useState<string | null>(null);
const [trainingRows, setTrainingRows] = useState<TrainingRow[] | null>(null);
const [file, setFile] = useState<File | null>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const handleFileDrop = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
const files = e.dataTransfer.files;
if (files.length > 0) {
processFile(files[0] as File);
}
};
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const files = e.target.files;
if (files && files.length > 0) {
processFile(files[0] as File);
}
};
const processFile = (file: File) => {
setFile(file);
// skip reading if file is larger than 10MB
if (file.size > 10000000) {
setTrainingRows(null);
return;
}
const reader = new FileReader();
reader.onload = (e: ProgressEvent<FileReader>) => {
const content = e.target?.result as string;
// Process the content, e.g., set to state
let parsedJSONL;
try {
parsedJSONL = parseJSONL(content) as TrainingRow[];
const validationError = validateTrainingRows(parsedJSONL);
if (validationError) {
setValidationError(validationError);
setTrainingRows(null);
return;
}
setTrainingRows(parsedJSONL);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
setValidationError("Unable to parse JSONL file: " + (e.message as string));
setTrainingRows(null);
return;
}
};
reader.readAsText(file);
};
const resetState = useCallback(() => {
setValidationError(null);
setTrainingRows(null);
setFile(null);
}, [setValidationError, setTrainingRows, setFile]);
useEffect(() => {
if (disclosure.isOpen) {
resetState();
}
}, [disclosure.isOpen, resetState]);
const triggerFileDownloadMutation = api.datasets.triggerFileDownload.useMutation();
const utils = api.useContext();
const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
if (!dataset || !file) return;
const blobName = await uploadDatasetEntryFile(file);
await triggerFileDownloadMutation.mutateAsync({
datasetId: dataset.id,
blobName,
fileName: file.name,
fileSize: file.size,
});
await utils.datasets.listFileUploads.invalidate();
disclosure.onClose();
}, [dataset, trainingRows, triggerFileDownloadMutation, file, utils]);
return (
<Modal
size={{ base: "xl", md: "2xl" }}
closeOnOverlayClick={false}
closeOnEsc={false}
{...disclosure}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Text>Upload Training Logs</Text>
</HStack>
</ModalHeader>
{!sendingInProgress && <ModalCloseButton />}
<ModalBody maxW="unset" p={8}>
<Box w="full" aspectRatio={1.5}>
{validationError && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<VStack w="full">
<Text fontSize={32} color="gray.500" fontWeight="bold">
Error
</Text>
<Text color="gray.500">{validationError}</Text>
</VStack>
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={resetState}
>
Try again
</Text>
</VStack>
)}
{!validationError && !file && (
<VStack
w="full"
h="full"
stroke="gray.300"
justifyContent="center"
borderRadius={8}
sx={{
"background-image": `url("data:image/svg+xml,%3csvg width='100%25' height='100%25' xmlns='http://www.w3.org/2000/svg'%3e%3crect x='2%25' y='2%25' width='96%25' height='96%25' fill='none' stroke='%23eee' stroke-width='4' stroke-dasharray='6%2c 14' stroke-dashoffset='0' stroke-linecap='square' rx='8' ry='8'/%3e%3c/svg%3e")`,
}}
onDragOver={(e) => e.preventDefault()}
onDrop={handleFileDrop}
>
<JsonFileIcon />
<Icon as={AiOutlineCloudUpload} boxSize={24} color="gray.300" />
<Text fontSize={32} color="gray.500" fontWeight="bold">
Drag & Drop
</Text>
<Text color="gray.500">
your .jsonl file here, or{" "}
<input
type="file"
ref={fileInputRef}
onChange={handleFileChange}
style={{ display: "none" }}
accept=".jsonl"
/>
<Text
as="span"
textDecor="underline"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={() => fileInputRef.current?.click()}
>
browse
</Text>
</Text>
</VStack>
)}
{!validationError && file && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<JsonFileIcon />
<VStack w="full">
{trainingRows ? (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
Success
</Text>
<Text color="gray.500">
We'll upload <b>{trainingRows.length}</b>{" "}
{pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
</Text>
</>
) : (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
{file.name}
</Text>
<Text color="gray.500">{formatFileSize(file.size)}</Text>
</>
)}
</VStack>
{!sendingInProgress && (
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={resetState}
>
Change file
</Text>
)}
</VStack>
)}
</Box>
</ModalBody>
<ModalFooter>
<HStack>
<Button
colorScheme="gray"
isDisabled={sendingInProgress}
onClick={disclosure.onClose}
minW={24}
>
Cancel
</Button>
<Button
colorScheme="orange"
onClick={sendJSONL}
isLoading={sendingInProgress}
minW={24}
isDisabled={!file || !!validationError}
>
Upload
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};
const JsonFileIcon = () => (
<Box position="relative" display="flex" alignItems="center" justifyContent="center">
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<Text position="absolute" color="orange.400" fontWeight="bold" fontSize={12} pt={4}>
JSONL
</Text>
</Box>
);

View File

@@ -0,0 +1,71 @@
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
export type TrainingRow = {
input: CreateChatCompletionRequestMessage[];
output?: CreateChatCompletionRequestMessage;
};
export const parseJSONL = (jsonlString: string): unknown[] => {
const lines = jsonlString.trim().split("\n");
let lineNumber = 0;
const parsedLines = [];
try {
for (const line of lines) {
lineNumber++;
parsedLines.push(JSON.parse(line));
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
throw new Error(`Error parsing line ${lineNumber}: ${e.message as string}`);
}
return parsedLines;
};
export const validateTrainingRows = (rows: unknown): string | null => {
if (!Array.isArray(rows)) return "training data is not an array";
for (let i = 0; i < rows.length; i++) {
const row = rows[i] as TrainingRow;
let errorMessage: string | null = null;
try {
errorMessage = validateTrainingRow(row);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
errorMessage = error.message;
}
if (errorMessage) return `row ${i + 1}: ${errorMessage}`;
}
return null;
};
const validateTrainingRow = (row: TrainingRow): string | null => {
if (!row) return "empty row";
if (!row.input) return "missing input";
// Validate input
if (!Array.isArray(row.input)) return "input is not an array";
if ((row.input as unknown[]).some((x) => typeof x !== "object"))
return "input contains invalid item";
if (row.input.some((x) => !x)) return "input contains empty item";
if (row.input.some((x) => !x.content && !x.function_call))
return "input contains item with no content or function_call";
if (row.input.some((x) => x.function_call && !x.function_call.arguments))
return "input contains item with function_call but no arguments";
if (row.input.some((x) => x.function_call && !x.function_call.name))
return "input contains item with function_call but no name";
// Validate output
if (row.output) {
if (typeof row.output !== "object") return "output is not an object";
if (!row.output.content && !row.output.function_call)
return "output contains no content or function_call";
if (row.output.function_call && !row.output.function_call.arguments)
return "output contains function_call but no arguments";
if (row.output.function_call && !row.output.function_call.name)
return "output contains function_call but no name";
}
return null;
};

View File

@@ -26,6 +26,9 @@ export const env = createEnv({
SMTP_PORT: z.string().default("placeholder"),
SMTP_LOGIN: z.string().default("placeholder"),
SMTP_PASSWORD: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_NAME: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_KEY: z.string().default("placeholder"),
AZURE_STORAGE_CONTAINER_NAME: z.string().default("placeholder"),
WORKER_CONCURRENCY: z
.string()
.default("10")
@@ -72,6 +75,9 @@ export const env = createEnv({
SMTP_PORT: process.env.SMTP_PORT,
SMTP_LOGIN: process.env.SMTP_LOGIN,
SMTP_PASSWORD: process.env.SMTP_PASSWORD,
AZURE_STORAGE_ACCOUNT_NAME: process.env.AZURE_STORAGE_ACCOUNT_NAME,
AZURE_STORAGE_ACCOUNT_KEY: process.env.AZURE_STORAGE_ACCOUNT_KEY,
AZURE_STORAGE_CONTAINER_NAME: process.env.AZURE_STORAGE_CONTAINER_NAME,
WORKER_CONCURRENCY: process.env.WORKER_CONCURRENCY,
WORKER_MAX_POOL_SIZE: process.env.WORKER_MAX_POOL_SIZE,
},

View File

@@ -8,9 +8,9 @@ const replicate = new Replicate({
});
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
"7b-chat": "d24902e3fa9b698cc208b5e63136c4e26e828659a9f09827ca6ec5bb83014381",
"13b-chat": "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3",
"70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
"7b-chat": "658b64a1e83d7caaba4ef10d5ee9c12c40770003f45852f05c2564962f921d8e",
"13b-chat": "7457c09004773f9f9710f7eb3b270287ffcebcfb23a13c8ec30cfb98f6bff9b2",
"70b-chat": "4dfd64cc207097970659087cf5670e3c1fbe02f83aa0f751e079cfba72ca790a",
};
export async function getCompletion(

View File

@@ -10,9 +10,9 @@ import {
useDisclosure,
} from "@chakra-ui/react";
import Link from "next/link";
import { useState, useEffect } from "react";
import { AiOutlineDatabase } from "react-icons/ai";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
@@ -24,7 +24,11 @@ import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable/Datas
import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator";
import { useAppStore } from "~/state/store";
import FineTuneButton from "~/components/datasets/FineTuneButton";
import ExperimentButton from "~/components/datasets/ExperimentButton";
// import ExperimentButton from "~/components/datasets/ExperimentButton";
import UploadDataButton from "~/components/datasets/UploadDataButton";
// import DownloadButton from "~/components/datasets/DownloadButton";
import DeleteButton from "~/components/datasets/DeleteButton";
import FileUploadsCard from "~/components/datasets/FileUploadsCard";
export default function Dataset() {
const utils = api.useContext();
@@ -100,12 +104,16 @@ export default function Dataset() {
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
<HStack w="full" justifyContent="flex-end">
<FineTuneButton />
<ExperimentButton />
<UploadDataButton />
{/* <ExperimentButton /> */}
{/* <DownloadButton /> */}
<DeleteButton />
</HStack>
<DatasetEntriesTable />
<DatasetEntryPaginator />
</VStack>
</VStack>
<FileUploadsCard />
</AppShell>
<DatasetConfigurationDrawer disclosure={drawerDisclosure} />
</>

View File

@@ -1,4 +1,3 @@
import { type Prisma } from "@prisma/client";
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import {
@@ -7,13 +6,18 @@ import {
type CreateChatCompletionRequestMessage,
} from "openai/resources/chat";
import { TRPCError } from "@trpc/server";
import { shuffle } from "lodash-es";
import archiver from "archiver";
import { WritableStreamBuffer } from "stream-buffers";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import hashObject from "~/server/utils/hashObject";
import { type JsonValue } from "type-fest";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";
export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure
@@ -94,7 +98,7 @@ export const datasetEntriesRouter = createTRPCRouter({
name: z.string(),
})
.optional(),
loggedCallIds: z.string().array(),
loggedCallIds: z.string().array().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
@@ -115,50 +119,33 @@ export const datasetEntriesRouter = createTRPCRouter({
return error("No datasetId or newDatasetParams provided");
}
const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
},
modelResponse: {
isNot: null,
if (!input.loggedCallIds) {
return error("No loggedCallIds provided");
}
const loggedCalls = await prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
},
modelResponse: {
isNot: null,
},
},
include: {
modelResponse: {
select: {
reqPayload: true,
respPayload: true,
inputTokens: true,
outputTokens: true,
},
},
include: {
modelResponse: {
select: {
reqPayload: true,
respPayload: true,
inputTokens: true,
outputTokens: true,
},
},
},
}),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TRAIN",
},
}),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TEST",
},
}),
]);
},
orderBy: { createdAt: "desc" },
});
const shuffledLoggedCalls = shuffle(loggedCalls);
const totalEntries = existingTrainingCount + existingTestingCount + loggedCalls.length;
const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount;
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
let i = 0;
for (const loggedCall of shuffledLoggedCalls) {
const trainingRows = loggedCalls.map((loggedCall) => {
const inputMessages = (
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
).messages;
@@ -166,24 +153,14 @@ export const datasetEntriesRouter = createTRPCRouter({
const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
} else {
output = {
role: "assistant",
content: "",
};
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
output: output as unknown as CreateChatCompletionRequestMessage,
};
});
datasetEntriesToCreate.push({
datasetId,
loggedCallId: loggedCall.id,
input: inputMessages as unknown as Prisma.InputJsonValue,
output: output as unknown as Prisma.InputJsonValue,
inputTokens: loggedCall.modelResponse?.inputTokens || 0,
outputTokens: loggedCall.modelResponse?.outputTokens || 0,
type: i < numTrainingToAdd ? "TRAIN" : "TEST",
});
i++;
}
const datasetEntriesToCreate = await formatEntriesFromTrainingRows(datasetId, trainingRows);
// Ensure dataset and dataset entries are created atomically
await prisma.$transaction([
@@ -198,13 +175,12 @@ export const datasetEntriesRouter = createTRPCRouter({
},
}),
prisma.datasetEntry.createMany({
data: shuffle(datasetEntriesToCreate),
data: datasetEntriesToCreate,
}),
]);
return success(datasetId);
}),
update: protectedProcedure
.input(
z.object({
@@ -242,7 +218,8 @@ export const datasetEntriesRouter = createTRPCRouter({
let parsedOutput = undefined;
let outputTokens = undefined;
if (input.updates.output) {
// The client might send "null" as a string, so we need to check for that
if (input.updates.output && input.updates.output !== "null") {
parsedOutput = JSON.parse(input.updates.output);
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
parsedOutput as unknown as ChatCompletion.Choice.Message,
@@ -293,4 +270,68 @@ export const datasetEntriesRouter = createTRPCRouter({
return success("Dataset entries deleted");
}),
export: protectedProcedure
.input(
z.object({
datasetId: z.string(),
datasetEntryIds: z.string().array(),
testingSplit: z.number(),
removeDuplicates: z.boolean(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
const datasetEntries = await ctx.prisma.datasetEntry.findMany({
where: {
id: {
in: input.datasetEntryIds,
},
},
});
let rows: TrainingRow[] = datasetEntries.map((entry) => ({
input: entry.input as unknown as CreateChatCompletionRequestMessage[],
output: entry.output as unknown as CreateChatCompletionRequestMessage,
}));
if (input.removeDuplicates) {
const deduplicatedRows = [];
const rowHashSet = new Set<string>();
for (const row of rows) {
const rowHash = hashObject(row as unknown as JsonValue);
if (!rowHashSet.has(rowHash)) {
rowHashSet.add(rowHash);
deduplicatedRows.push(row);
}
}
rows = deduplicatedRows;
}
const splitIndex = Math.floor((rows.length * input.testingSplit) / 100);
const testingData = rows.slice(0, splitIndex);
const trainingData = rows.slice(splitIndex);
// Convert arrays to JSONL format
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
const output = new WritableStreamBuffer();
const archive = archiver("zip");
archive.pipe(output);
archive.append(trainingDataJSONL, { name: "train.jsonl" });
archive.append(testingDataJSONL, { name: "test.jsonl" });
await archive.finalize();
// Convert buffer to base64
const base64 = output.getContents().toString("base64");
return base64;
}),
});

View File

@@ -1,8 +1,11 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { success } from "~/utils/errorHandling/standardResponses";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { generateServiceClientUrl } from "~/utils/azure/server";
import { queueImportDatasetEntries } from "~/server/tasks/importDatasetEntries.task";
export const datasetsRouter = createTRPCRouter({
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
@@ -94,4 +97,87 @@ export const datasetsRouter = createTRPCRouter({
return success("Dataset deleted");
}),
getServiceClientUrl: protectedProcedure
.input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => {
// The user must be able to modify the project to get a SAS token
await requireCanModifyProject(input.projectId, ctx);
return generateServiceClientUrl();
}),
triggerFileDownload: protectedProcedure
.input(
z.object({
datasetId: z.string(),
blobName: z.string(),
fileName: z.string(),
fileSize: z.number(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
const { id } = await prisma.datasetFileUpload.create({
data: {
datasetId: input.datasetId,
blobName: input.blobName,
status: "PENDING",
fileName: input.fileName,
fileSize: input.fileSize,
uploadedAt: new Date(),
},
});
await queueImportDatasetEntries(id);
}),
listFileUploads: protectedProcedure
.input(z.object({ datasetId: z.string() }))
.query(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
return await prisma.datasetFileUpload.findMany({
where: {
datasetId: input.datasetId,
visible: true,
},
orderBy: { createdAt: "desc" },
});
}),
hideFileUploads: protectedProcedure
.input(z.object({ fileUploadIds: z.string().array() }))
.mutation(async ({ input, ctx }) => {
if (!input.fileUploadIds.length) return error("No file upload ids provided");
const {
dataset: { projectId, id: datasetId },
} = await prisma.datasetFileUpload.findUniqueOrThrow({
where: { id: input.fileUploadIds[0] },
select: {
dataset: {
select: {
id: true,
projectId: true,
},
},
},
});
await requireCanModifyProject(projectId, ctx);
await prisma.datasetFileUpload.updateMany({
where: {
id: {
in: input.fileUploadIds,
},
datasetId,
},
data: {
visible: false,
},
});
}),
});

View File

@@ -0,0 +1,152 @@
import { type DatasetFileUpload } from "@prisma/client";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { downloadBlobToString } from "~/utils/azure/server";
import {
type TrainingRow,
validateTrainingRows,
parseJSONL,
} from "~/components/datasets/validateTrainingRows";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";
export type ImportDatasetEntriesJob = {
datasetFileUploadId: string;
};
export const importDatasetEntries = defineTask<ImportDatasetEntriesJob>(
"importDatasetEntries",
async (task) => {
const { datasetFileUploadId } = task;
const datasetFileUpload = await prisma.datasetFileUpload.findUnique({
where: { id: datasetFileUploadId },
});
if (!datasetFileUpload) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: "Dataset File Upload not found",
status: "ERROR",
},
});
return;
}
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "DOWNLOADING",
progress: 5,
},
});
const onBlobDownloadProgress = async (progress: number) => {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
progress: 5 + Math.floor((progress / datasetFileUpload.fileSize) * 25),
},
});
};
const jsonlStr = await downloadBlobToString(datasetFileUpload.blobName, onBlobDownloadProgress);
let trainingRows: TrainingRow[] = [];
let validationError: string | null = null;
try {
trainingRows = parseJSONL(jsonlStr) as TrainingRow[];
validationError = validateTrainingRows(trainingRows);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
validationError = e.message;
}
if (validationError) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: `Invalid JSONL: ${validationError}`,
status: "ERROR",
},
});
return;
}
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "PROCESSING",
progress: 30,
},
});
const updatePromises: Promise<DatasetFileUpload>[] = [];
const updateCallback = async (progress: number) => {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
progress: 30 + Math.floor((progress / trainingRows.length) * 69),
},
});
};
let datasetEntriesToCreate;
try {
datasetEntriesToCreate = await formatEntriesFromTrainingRows(
datasetFileUpload.datasetId,
trainingRows,
updateCallback,
500,
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: `Error formatting rows: ${e.message as string}`,
status: "ERROR",
visible: true,
},
});
return;
}
await Promise.all(updatePromises);
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "SAVING",
progress: 99,
},
});
await prisma.datasetEntry.createMany({
data: datasetEntriesToCreate,
});
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "COMPLETE",
progress: 100,
visible: true,
},
});
},
);
export const queueImportDatasetEntries = async (datasetFileUploadId: string) => {
await Promise.all([
prisma.datasetFileUpload.update({
where: {
id: datasetFileUploadId,
},
data: {
errorMessage: null,
status: "PENDING",
},
}),
importDatasetEntries.enqueue({ datasetFileUploadId }),
]);
};

View File

@@ -5,10 +5,11 @@ import "../../../sentry.server.config";
import { env } from "~/env.mjs";
import { queryModel } from "./queryModel.task";
import { runNewEval } from "./runNewEval.task";
import { importDatasetEntries } from "./importDatasetEntries.task";
console.log("Starting worker");
const registeredTasks = [queryModel, runNewEval];
const registeredTasks = [queryModel, runNewEval, importDatasetEntries];
const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;

View File

@@ -0,0 +1,70 @@
import { type Prisma } from "@prisma/client";
import { shuffle } from "lodash-es";
import {
type CreateChatCompletionRequestMessage,
type ChatCompletion,
} from "openai/resources/chat";
import { prisma } from "~/server/db";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import { countLlamaChatTokens } from "~/utils/countTokens";
export const formatEntriesFromTrainingRows = async (
datasetId: string,
trainingRows: TrainingRow[],
updateCallback?: (progress: number) => Promise<void>,
updateFrequency = 1000,
) => {
const [dataset, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.dataset.findUnique({ where: { id: datasetId } }),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TRAIN",
},
}),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TEST",
},
}),
]);
const trainingRatio = dataset?.trainingRatio ?? 0.8;
const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
const typesToAssign = shuffle([
...Array(numTrainingToAdd).fill("TRAIN"),
...Array(numTestingToAdd).fill("TEST"),
]);
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
let i = 0;
for (const row of trainingRows) {
// console.log(row);
if (updateCallback && i % updateFrequency === 0) await updateCallback(i);
let outputTokens = 0;
if (row.output) {
outputTokens = countLlamaChatTokens([row.output as unknown as ChatCompletion.Choice.Message]);
}
// console.log("outputTokens", outputTokens);
datasetEntriesToCreate.push({
datasetId: datasetId,
input: row.input as unknown as Prisma.InputJsonValue,
output: (row.output as unknown as Prisma.InputJsonValue) ?? {
role: "assistant",
content: "",
},
inputTokens: countLlamaChatTokens(
row.input as unknown as CreateChatCompletionRequestMessage[],
),
outputTokens,
type: typesToAssign.pop() as "TRAIN" | "TEST",
});
i++;
}
return datasetEntriesToCreate;
};

View File

@@ -0,0 +1,93 @@
import {
BlobServiceClient,
generateAccountSASQueryParameters,
AccountSASPermissions,
AccountSASServices,
AccountSASResourceTypes,
StorageSharedKeyCredential,
SASProtocol,
} from "@azure/storage-blob";
const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
if (!accountName) throw Error("Azure Storage accountName not found");
const accountKey = process.env.AZURE_STORAGE_ACCOUNT_KEY;
if (!accountKey) throw Error("Azure Storage accountKey not found");
const containerName = process.env.AZURE_STORAGE_CONTAINER_NAME;
if (!containerName) throw Error("Azure Storage containerName not found");
const sharedKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);
const blobServiceClient = new BlobServiceClient(
`https://${accountName}.blob.core.windows.net`,
sharedKeyCredential,
);
const containerClient = blobServiceClient.getContainerClient(containerName);
export const generateServiceClientUrl = () => {
const sasOptions = {
services: AccountSASServices.parse("b").toString(), // blobs
resourceTypes: AccountSASResourceTypes.parse("sco").toString(), // service, container, object
permissions: AccountSASPermissions.parse("w"), // write permissions
protocol: SASProtocol.Https,
startsOn: new Date(),
expiresOn: new Date(new Date().valueOf() + 10 * 60 * 1000), // 10 minutes
};
let sasToken = generateAccountSASQueryParameters(sasOptions, sharedKeyCredential).toString();
// remove leading "?"
sasToken = sasToken[0] === "?" ? sasToken.substring(1) : sasToken;
return {
serviceClientUrl: `https://${accountName}.blob.core.windows.net?${sasToken}`,
containerName,
};
};
export async function downloadBlobToString(
blobName: string,
onProgress?: (progress: number) => Promise<void>,
chunkInterval?: number,
) {
const blobClient = containerClient.getBlobClient(blobName);
const downloadResponse = await blobClient.download();
if (!downloadResponse) throw Error("error downloading blob");
if (!downloadResponse.readableStreamBody)
throw Error("downloadResponse.readableStreamBody not found");
const downloaded = await streamToBuffer(
downloadResponse.readableStreamBody,
onProgress,
chunkInterval,
);
return downloaded.toString();
}
async function streamToBuffer(
readableStream: NodeJS.ReadableStream,
onProgress?: (progress: number) => Promise<void>,
chunkInterval = 1048576, // send progress every 1MB
): Promise<Buffer> {
return new Promise((resolve, reject) => {
const chunks: Uint8Array[] = [];
let bytesDownloaded = 0;
let lastReportedByteCount = 0;
readableStream.on("data", (data: ArrayBuffer) => {
chunks.push(data instanceof Buffer ? data : Buffer.from(data));
bytesDownloaded += data.byteLength;
if (onProgress && bytesDownloaded - lastReportedByteCount >= chunkInterval) {
void onProgress(bytesDownloaded); // progress in Bytes
lastReportedByteCount = bytesDownloaded;
}
});
readableStream.on("end", () => {
resolve(Buffer.concat(chunks));
});
readableStream.on("error", reject);
});
}

View File

@@ -0,0 +1,30 @@
import { BlobServiceClient } from "@azure/storage-blob";
import { v4 as uuidv4 } from "uuid";
import { useAppStore } from "~/state/store";
export const uploadDatasetEntryFile = async (file: File) => {
const { selectedProjectId: projectId, api } = useAppStore.getState();
if (!projectId) throw Error("projectId not found");
if (!api) throw Error("api not initialized");
const { serviceClientUrl, containerName } = await api.client.datasets.getServiceClientUrl.query({
projectId,
});
const blobServiceClient = new BlobServiceClient(serviceClientUrl);
// create container client
const containerClient = blobServiceClient.getContainerClient(containerName);
// base name without extension
const basename = file.name.split("/").pop()?.split(".").shift();
if (!basename) throw Error("basename not found");
const blobName = `${basename}-${uuidv4()}.jsonl`;
// create blob client
const blobClient = containerClient.getBlockBlobClient(blobName);
// upload file
await blobClient.uploadData(file);
return blobName;
};

View File

@@ -1,5 +1,7 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
import llamaTokenizer from "llama-tokenizer-js";
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
interface GPTTokensMessageItem {
@@ -22,3 +24,11 @@ export const countOpenAIChatTokens = (
messages: reformattedMessages as unknown as GPTTokensMessageItem[],
}).usedTokens;
};
export const countLlamaChatTokens = (messages: ChatCompletion.Choice.Message[]) => {
const stringToTokenize = messages
.map((message) => message.content || JSON.stringify(message.function_call))
.join("\n");
const tokens = llamaTokenizer.encode(stringToTokenize);
return tokens.length;
};

View File

@@ -20,7 +20,6 @@ export const parseableToFunctionCall = (str: string) => {
} catch {
return false;
}
console.log("remove me");
// Check if the parsedJSON is an object and not null
if (typeof parsedJSON !== "object" || parsedJSON === null) {
@@ -53,3 +52,18 @@ export const parseableToFunctionCall = (str: string) => {
return true;
};
export const formatFileSize = (bytes: number, decimals = 2) => {
if (bytes === 0) return "0 Bytes";
const k = 1024;
const dm = decimals < 0 ? 0 : decimals;
const sizes = ["Bytes", "KB", "MB", "GB", "TB"];
for (const size of sizes) {
if (bytes < k) return `${parseFloat(bytes.toFixed(dm))} ${size}`;
bytes /= k;
}
return "> 1024 TB";
};

View File

@@ -19,7 +19,9 @@
"baseUrl": ".",
"paths": {
"~/*": ["./src/*"]
}
},
"typeRoots": ["./types", "./node_modules/@types"],
"types": ["llama-tokenizer-js", "node"]
},
"include": [
".eslintrc.cjs",

View File

@@ -0,0 +1,4 @@
declare module "llama-tokenizer-js" {
export function encode(input: string): number[];
export function decode(input: number[]): string;
}

View File

@@ -1,27 +1,28 @@
#!/usr/bin/env bash
# Adapted from https://github.com/openai/openai-node/blob/master/build
set -exuo pipefail
rm -rf dist /tmp/openpipe-build-dist
rm -rf dist
mkdir /tmp/openpipe-build-dist
npx tsup
cp -rp * /tmp/openpipe-build-dist
# copy the package.json file to /dist
cp package.json dist
# copy the README.md file to /dist
cp README.md dist
# Rename package name in package.json
python3 -c "
import json
with open('/tmp/openpipe-build-dist/package.json', 'r') as f:
data = json.load(f)
# Load the package.json file
with open('dist/package.json', 'r') as file:
data = json.load(file)
# Change the names
data['name'] = 'openpipe'
with open('/tmp/openpipe-build-dist/package.json', 'w') as f:
json.dump(data, f, indent=4)
# Write the changes back to the package.json file
with open('dist/package.json', 'w') as file:
json.dump(data, file, indent=2)
"
rm -rf /tmp/openpipe-build-dist/node_modules
mv /tmp/openpipe-build-dist dist
# build to .js files
(cd dist && npm exec tsc -- --noEmit false)

View File

@@ -1,16 +1,34 @@
{
"name": "openpipe-dev",
"version": "0.3.5",
"version": "0.4.0-beta.3",
"type": "module",
"description": "Metrics and auto-evaluation for LLM calls",
"scripts": {
"build": "./build.sh",
"build-update": "./build.sh && ./update-app.sh",
"test": "vitest"
},
"main": "./index.ts",
"main": "./src/index.ts",
"publishConfig": {
"name": "openpipe",
"access": "public",
"main": "./index.js"
"main": "./index.cjs",
"module": "./index.js",
"types": "./index.d.ts",
"exports": {
".": {
"import": "./index.js",
"require": "./index.cjs"
},
"./openai": {
"import": "./openai.js",
"require": "./openai.cjs"
},
"./openai/mergeChunks": {
"import": "./openai/mergeChunks.js",
"require": "./openai/mergeChunks.cjs"
}
}
},
"keywords": [],
"author": "",
@@ -24,10 +42,16 @@
"openai-legacy": "npm:openai@3.3.0"
},
"devDependencies": {
"@rollup/plugin-json": "^6.0.0",
"@rollup/plugin-node-resolve": "^15.2.1",
"@types/lodash-es": "^4.17.8",
"@types/node": "^20.4.8",
"@types/node-fetch": "^2.6.4",
"dotenv": "^16.3.1",
"rollup": "^3.28.1",
"rollup-plugin-typescript2": "^0.35.0",
"tslib": "^2.6.2",
"tsup": "^7.2.0",
"tsx": "^3.12.7",
"typescript": "^5.0.4",
"vitest": "^0.33.0"

View File

@@ -6,4 +6,4 @@ set -exuo pipefail
./build.sh
(cd dist && pnpm publish --access public)
(cd dist && pnpm publish --access public --tag beta --no-git-checks)

View File

@@ -7,10 +7,10 @@ import {
CompletionCreateParams,
} from "openai-beta/resources/chat/completions";
import { WrappedStream } from "./streaming";
import { WrappedStream } from "./openai/streaming";
import { DefaultService, OPClient } from "../codegen";
import { Stream } from "openai-beta/streaming";
import { OpenPipeArgs, OpenPipeMeta, type OpenPipeConfig, getTags } from "../shared";
import { OpenPipeArgs, OpenPipeMeta, type OpenPipeConfig, getTags } from "./shared";
export type ClientOptions = openai.ClientOptions & { openpipe?: OpenPipeConfig };
export default class OpenAI extends openai.OpenAI {

View File

@@ -1,12 +1,12 @@
import dotenv from "dotenv";
import { expect, test } from "vitest";
import OpenAI from ".";
import OpenAI from "../openai";
import {
ChatCompletion,
CompletionCreateParams,
CreateChatCompletionRequestMessage,
} from "openai-beta/resources/chat/completions";
import { OPClient } from "../codegen";
import { OPClient } from "../../codegen";
import mergeChunks from "./mergeChunks";
import assert from "assert";

View File

@@ -1,6 +1,6 @@
import pkg from "./package.json";
import pkg from "../package.json";
import { DefaultService } from "./codegen";
import { DefaultService } from "../codegen";
export type OpenPipeConfig = {
apiKey?: string;

View File

@@ -12,7 +12,6 @@
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"incremental": true,
"noUncheckedIndexedAccess": true,
"noEmit": true,
"sourceMap": true,

View File

@@ -0,0 +1,24 @@
import { Options } from "tsup";
const config: Options = {
splitting: false, // Disable code splitting
sourcemap: true, // Include sourcemaps
target: "es2020", // Target ES2020 syntax for modern environments
dts: true, // Generate declaration files
// Define entry points
entry: [
"src/index.ts", // Main entry
"src/openai.ts", // 'openai' sub-module
"src/openai/mergeChunks.ts", // 'openai/mergeChunks' sub-module
"src/openai/streaming.ts", // 'openai/streaming' sub-module
],
// Define format of the output bundles
format: ["cjs", "esm"],
// External libraries that shouldn't be bundled
external: [...Object.keys(require("./package.json").dependencies || {})],
};
export default config;

View File

@@ -0,0 +1,4 @@
(cd dist && yalc publish)
(cd ../../app && yalc add openpipe)
(cd ../../app && pnpm install)

View File

@@ -0,0 +1,19 @@
---
title: "Import Data - Beta"
sidebarTitle: "Import Data"
description: "
Import external data to kickstart your fine-tuning process. Use the OpenAI chat fine-tuning format."
---
Upload a JSONL file in which each line contains an array of OpenAI chat messages as the input and a single assistant message as the output.
<Frame>![](/images/features/importing-data.png)</Frame>
The format is compatible with both standard content messages and function calls. Each input array must contain at least one message.
```jsonl
...
{"input":[{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"What is the capitol of Sweden?"}],"output":{"role":"assistant","content":null,"function_call":{"name":"log_capitol","arguments":"{\"capitol\":\"Stockholm\"}"}}}
{"input":[{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"What is the capitol of Tasmania?"}],"output":{"role":"assistant","content":null,"function_call":{"name":"log_capitol","arguments":"{\"capitol\":\"Hobart\"}"}}}
...
```
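Each line is a standalone JSON object, so rows are easy to generate programmatically. A minimal TypeScript sketch (the `Message` and `TrainingRow` types and the file name are illustrative, not part of the OpenPipe API):

```ts
import { appendFileSync } from "node:fs";

// Illustrative shapes for one JSONL row (not official OpenPipe types).
type Message = {
  role: "system" | "user" | "assistant";
  content: string | null;
  function_call?: { name: string; arguments: string };
};
type TrainingRow = { input: Message[]; output: Message };

const row: TrainingRow = {
  input: [
    { role: "system", content: "You are a helpful assistant" },
    { role: "user", content: "What is the capital of Sweden?" },
  ],
  output: {
    role: "assistant",
    content: null,
    function_call: {
      name: "log_capital",
      arguments: JSON.stringify({ capital: "Stockholm" }),
    },
  },
};

// JSONL: one JSON object per line.
appendFileSync("dataset.jsonl", JSON.stringify(row) + "\n");
```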

Binary file not shown (new image, 408 KiB: /images/features/importing-data.png)

View File

@@ -41,6 +41,7 @@
"group": "Features",
"pages": [
"features/log-filters",
"features/importing-data",
"features/exporting-data",
"features/fine-tuning",
"features/experiments"

examples/.gitignore
View File

@@ -1,4 +1,6 @@
axolotl/
models/
data/
wandb/
wandb/
cache/
.ipynb_checkpoints/

View File

@@ -1,473 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 21:25:06\n",
"Current Time: 2023-08-24 21:25:36\n"
]
}
],
"source": [
"import time\n",
"\n",
"while True:\n",
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
" print(f\"Current Time: {current_time}\")\n",
" time.sleep(30)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 30,000 to 40,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 40,000 to 50,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Cost to Classify One Recipe</th>\n",
" <th>Cost to Classify Entire Dataset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
"finetuned_hourly_cost = 1.14\n",
"\n",
"finetuned_total_hours = 16.54\n",
"\n",
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
"\n",
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
"avg_input_tokens = 276\n",
"avg_output_tokens = 42\n",
"\n",
"# Token pricing from https://openai.com/pricing\n",
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
"\n",
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
"\n",
"gpt_35_finetuned_avg_cost = (\n",
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"# Multiply the number of recipes\n",
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
"\n",
"# Let's put this in a dataframe for easier comparison.\n",
"\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
" finetuned_avg_cost,\n",
" gpt_35_avg_cost,\n",
" gpt_35_finetuned_avg_cost,\n",
" gpt_4_avg_cost,\n",
" ],\n",
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"\n",
"costs\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -6,61 +6,24 @@
"source": [
"In this notebook I'm using the OpenPipe client to capture a set of calls to the OpenAI API.\n",
"\n",
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁"
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (2.8.2)\n",
"Requirement already satisfied: toml<0.11.0,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.10.2)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (2022.12.7)\n",
"Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.17.3)\n",
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.4)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.3.0)\n",
"Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.28.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.66.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (3.8.5)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil<3.0.0,>=2.8.2->openpipe==3.0.3) (1.16.0)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.14.0)\n",
"Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.7.1)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.1.1)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.26.13)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.3.1)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2 datasets==2.14.4"
"%%capture\n",
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 datasets==2.14.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe."
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe.\n"
]
},
{
@@ -132,15 +95,16 @@
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
"\n",
"In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
" - has_non_fish_meat\n",
" - requires_oven\n",
" - requires_stove\n",
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"- has_non_fish_meat\n",
"- requires_oven\n",
"- requires_stove\n",
"- cook_time_over_30_mins\n",
"- main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: I'm looking for meals that my pescatarian brother/co-founder can make in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example).\n"
]
},
{
@@ -240,7 +204,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!"
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!\n"
]
},
{
@@ -320,11 +284,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, now that my recipes are classified I'll download the training data. \n",
"Ok, now that my recipes are classified I'll download the training data.\n",
"\n",
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details! Just go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
"\n",
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`."
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`.\n"
]
}
],

View File

@@ -4,142 +4,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's get to the fun part -- training a model. I'll start by installing the dependencies."
"Now let's get to the fun part -- training a model. I'll start by installing the dependencies.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: peft==0.5.0 in /usr/local/lib/python3.10/dist-packages (0.5.0)\n",
"\u001b[31mERROR: Could not find a version that satisfies the requirement python-dotenv==2.0.0 (from versions: 0.1.0, 0.1.2, 0.1.3, 0.1.5, 0.2.0, 0.3.0, 0.4.0, 0.5.0, 0.5.1, 0.6.0, 0.6.1, 0.6.2, 0.6.3, 0.6.4, 0.6.5, 0.7.0, 0.7.1, 0.8.0, 0.8.1, 0.8.2, 0.9.0, 0.9.1, 0.10.0, 0.10.1, 0.10.2, 0.10.3, 0.10.4, 0.10.5, 0.11.0, 0.12.0, 0.13.0, 0.14.0, 0.15.0, 0.16.0, 0.17.0, 0.17.1, 0.18.0, 0.19.0, 0.19.1, 0.19.2, 0.20.0, 0.21.0, 0.21.1, 1.0.0)\u001b[0m\u001b[31m\n",
"\u001b[0m\u001b[31mERROR: No matching distribution found for python-dotenv==2.0.0\u001b[0m\u001b[31m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"fatal: destination path 'axolotl' already exists and is not an empty directory.\n",
"Obtaining file:///workspace/OpenPipe/examples/classify-recipes/axolotl\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hCollecting transformers@ git+https://github.com/huggingface/transformers.git (from axolotl==0.1)\n",
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-install-ckp96ans/transformers_783779e09ad546a5be81c173eca5fd38\n",
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-ckp96ans/transformers_783779e09ad546a5be81c173eca5fd38\n",
" Resolved https://github.com/huggingface/transformers.git to commit f26099e7b5cf579f99a42bab6ddd371bf2c8d548\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hCollecting accelerate@ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b (from axolotl==0.1)\n",
" Using cached accelerate-0.22.0.dev0-py3-none-any.whl\n",
"Requirement already satisfied: bitsandbytes>=0.41.1 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.41.1)\n",
"Requirement already satisfied: addict in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.4.0)\n",
"Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.5.0)\n",
"Requirement already satisfied: PyYAML==6.0 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (6.0)\n",
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.14.4)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.99)\n",
"Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.15.8)\n",
"Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.6.1)\n",
"Requirement already satisfied: xformers in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.0.21)\n",
"Requirement already satisfied: optimum in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.11.2)\n",
"Requirement already satisfied: hf_transfer in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.3)\n",
"Requirement already satisfied: colorama in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.4.6)\n",
"Requirement already satisfied: numba in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.57.1)\n",
"Requirement already satisfied: numpy==1.24.4 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.24.4)\n",
"Requirement already satisfied: bert-score==0.3.13 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.3.13)\n",
"Requirement already satisfied: evaluate==0.4.0 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.4.0)\n",
"Requirement already satisfied: rouge-score==0.1.2 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (0.1.2)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.11.2)\n",
"Requirement already satisfied: scikit-learn==1.2.2 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (1.2.2)\n",
"Requirement already satisfied: pynvml in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (11.5.0)\n",
"Requirement already satisfied: flash-attn==2.0.8 in /usr/local/lib/python3.10/dist-packages (from axolotl==0.1) (2.0.8)\n",
"Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.0.1+cu118)\n",
"Requirement already satisfied: pandas>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.0.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.31.1 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (4.66.1)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (3.7.2)\n",
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from bert-score==0.3.13->axolotl==0.1) (23.1)\n",
"Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.3.7)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (2023.6.0)\n",
"Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.16.4)\n",
"Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate==0.4.0->axolotl==0.1) (0.18.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from flash-attn==2.0.8->axolotl==0.1) (1.11.1)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (1.4.0)\n",
"Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (3.8.1)\n",
"Requirement already satisfied: six>=1.14.0 in /usr/lib/python3/dist-packages (from rouge-score==0.1.2->axolotl==0.1) (1.16.0)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.2.2->axolotl==0.1) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.2.2->axolotl==0.1) (3.2.0)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->axolotl==0.1) (12.0.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->axolotl==0.1) (3.8.5)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (3.9.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers@ git+https://github.com/huggingface/transformers.git->axolotl==0.1) (0.3.2)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate@ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b->axolotl==0.1) (5.9.5)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from fire->axolotl==0.1) (2.3.0)\n",
"Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba->axolotl==0.1) (0.40.1)\n",
"Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from optimum->axolotl==0.1) (15.0.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum->axolotl==0.1) (1.11.1)\n",
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (8.1.7)\n",
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (3.1.32)\n",
"Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.29.2)\n",
"Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (0.4.0)\n",
"Requirement already satisfied: pathtools in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (0.1.2)\n",
"Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.3.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (68.0.0)\n",
"Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (1.4.4)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb->axolotl==0.1) (4.24.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (4.7.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (15.0.7)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->axolotl==0.1) (1.3.1)\n",
"Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (4.0.10)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.0.1->bert-score==0.3.13->axolotl==0.1) (2023.3)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->bert-score==0.3.13->axolotl==0.1) (2022.12.7)\n",
"Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->optimum->axolotl==0.1) (10.0)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (4.42.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (1.4.4)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (9.3.0)\n",
"Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib->bert-score==0.3.13->axolotl==0.1) (2.4.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum->axolotl==0.1) (1.2.1)\n",
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->axolotl==0.1) (5.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.0.0->bert-score==0.3.13->axolotl==0.1) (2.1.2)\n",
"Installing collected packages: axolotl\n",
" Attempting uninstall: axolotl\n",
" Found existing installation: axolotl 0.1\n",
" Uninstalling axolotl-0.1:\n",
" Successfully uninstalled axolotl-0.1\n",
" Running setup.py develop for axolotl\n",
"Successfully installed axolotl\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%%capture\n",
"%pip install peft==0.5.0 python-dotenv==2.0.0\n",
"\n",
"!git clone https://github.com/OpenAccess-AI-Collective/axolotl\n",
@@ -150,7 +24,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note to the reader: since we'll be basing our fine-tuned model on Meta's Llama 2, you need to apply for access to the weights (which will be automatically granted). Follow the steps on [HuggingFace](https://huggingface.co/meta-llama/Llama-2-7b-hf), then create a read-only access token [here](https://huggingface.co/settings/tokens) and copy it into your .env file."
"Note to the reader: since we'll be basing our fine-tuned model on Meta's Llama 2, you need to apply for access to the weights (which will be automatically granted). Follow the steps on [HuggingFace](https://huggingface.co/meta-llama/Llama-2-7b-hf), then create a read-only access token [here](https://huggingface.co/settings/tokens) and copy it into your .env file.\n"
]
},
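For readers following along, here's a minimal sketch of how that token can be exposed to the notebook once it's in `.env`. The variable name `HUGGING_FACE_HUB_TOKEN` is an assumption, not something this repo defines:

```python
# Minimal sketch (assumed variable name): load the HuggingFace token from .env
import os
import dotenv

dotenv.load_dotenv()
hf_token = os.environ["HUGGING_FACE_HUB_TOKEN"]  # read-only token from your HF settings
```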
{
@@ -185,7 +59,7 @@
"\n",
"In this case I'm using 8-bit training to use less GPU RAM, and sample packing to maximize GPU utilization. You can read more about the available options at https://github.com/OpenAccess-AI-Collective/axolotl.\n",
"\n",
"The training run options are defined in [training-config.yaml](./training-config.yaml)."
"The training run options are defined in [training-config.yaml](./training-config.yaml).\n"
]
},
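The launch cell itself is collapsed in this diff. For reference, a typical axolotl invocation at the time looked roughly like the sketch below; the script path and the use of accelerate are assumptions based on axolotl's README, not copied from this notebook:

```python
# Hypothetical launch cell; axolotl's entry point and flags are assumptions
!accelerate launch axolotl/scripts/finetune.py training-config.yaml
```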
{
@@ -991,7 +865,7 @@
"source": [
"Sweet! I now have a new directory `./models/run1`. This contains my trained model, which I can use to classify more recipes.\n",
"\n",
"There's one more step though. I trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models directly yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-shaped model. I've defined a small helper to do that called `merge_lora_model` that I'll use below."
"There's one more step though. I trained our model using [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora), which is a memory-efficient training method. But the inference library we'll use for testing doesn't support LoRA models directly yet, so we need to \"merge\" our LoRA model to transform it into a standard Llama2-shaped model. I've defined a small helper to do that called `merge_lora_model` that I'll use below.\n"
]
},
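As a rough sketch of what a helper like `merge_lora_model` does under the hood with peft (the real helper lives in this repo; the base-model name and paths below are assumptions):

```python
# Sketch of a LoRA merge via peft's merge_and_unload (assumed paths)
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base, "./models/run1")  # load the LoRA adapter
merged = model.merge_and_unload()  # fold the adapter weights into the base model
merged.save_pretrained("./models/run1/merged")  # standard Llama2-shaped checkpoint
```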
{
@@ -1044,7 +918,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check."
"Ok, I have a model, but is it actually any good? I'll run some evaluations in [./evaluate.ipynb](./evaluate.ipynb) to check.\n"
]
}
],


@@ -4,96 +4,29 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first."
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%pip install vllm==0.1.3 pandas==2.0.3"
"%%capture\n",
"%pip install vllm==0.1.3 pandas==2.0.3 joblib==1.3.2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance."
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance.\n"
]
},
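The loading cell's source is collapsed in this diff; presumably it amounts to something like the following (the exact path is an assumption):

```python
import pandas as pd

# Assumed: load the held-out test split exported from OpenPipe
test_data = pd.read_json("data/test.jsonl", lines=True)
```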
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -106,12 +39,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results."
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results.\n"
]
},
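The transformation cell is collapsed in this diff. For reference, the standard Alpaca template looks like the sketch below; whether the notebook uses this exact instruction-only variant is an assumption:

```python
# Standard Alpaca prompt template (instruction-only variant)
def to_alpaca_prompt(instruction: str) -> str:
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:\n"
    )
```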
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -147,27 +80,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model.\n"
]
},
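The generation cell is partly collapsed here; the core vLLM calls look roughly like this sketch (the sampling settings are assumptions):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="./models/run1/merged")  # load the merged model
sampling = SamplingParams(temperature=0, max_tokens=256)  # assumed settings
outputs = llm.generate(prompts, sampling)  # `prompts`: Alpaca-formatted test rows
my_outputs = [o.outputs[0].text for o in outputs]
```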
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-25 03:58:49 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-25 03:59:40 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
"INFO 08-28 00:26:23 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-28 00:27:26 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.42it/s]"
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.34it/s]"
]
},
{
@@ -211,12 +144,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:"
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 32,
"metadata": {},
"outputs": [
{
@@ -239,20 +172,29 @@
" return args_dict\n",
"\n",
"\n",
"def calculate_accuracy(row):\n",
"test_data[\"output_parsed\"] = test_data[\"output\"].apply(parse_fn_call)\n",
"test_data[\"my_outputs_parsed\"] = test_data[\"my_outputs\"].apply(parse_fn_call)\n",
"\n",
"\n",
"def calculate_accuracy(row, labels_col):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = parse_fn_call(row[\"output\"])\n",
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
" true_outputs = row[\"output_parsed\"]\n",
" labels_outputs = row[labels_col]\n",
"\n",
" # print(f\"true_outputs: {true_outputs}\")\n",
" # print(f\"my_outputs: {row[labels_col]}\")\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
" if key in labels_outputs and true_outputs[key] == labels_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
"\n",
"\n",
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
"test_data[\"accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"my_outputs_parsed\"\n",
")\n",
"\n",
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
]
@@ -261,12 +203,293 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look."
"95% seems good! However, we don't have much to compare it to. Let's see how GPT-3.5 would do on the same task as a baseline. We'll use the same prompt we used with GPT-4 to generate the labels.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample recipe:\n",
"--------------\n",
"Pan Gravy\n",
"\n",
"Ingredients:\n",
"- 1/3 cup all purpose flour\n",
"- 1/3 cup turkey drippings\n",
"- 3 cup water or broth\n",
"- 1/8 to 1/4 teaspoon salt\n",
"- 1/8 tsp pepper\n",
"\n",
"Directions:\n",
"- In a skillet or roasting pan, add flour to drippings; blend well.\n",
"- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\n",
"- Add water; cook until mixture boils and thickens, stirring constantly.\n",
"- Stir in salt and pepper.\n",
"- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\n",
"- *\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def extract_recipe(row):\n",
" \"\"\"Extract the recipe from the instruction\"\"\"\n",
" return json.loads(row[\"instruction\"])[1][\"content\"]\n",
"\n",
"\n",
"recipes = test_data.apply(extract_recipe, axis=1)\n",
"print(f\"Sample recipe:\\n--------------\\n{recipes[0]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classifying first recipe:\n",
"------------------\n",
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': False, 'main_dish': False}\n"
]
}
],
"source": [
"import joblib\n",
"import openai\n",
"import os\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"memory = joblib.Memory(\"./cache\", verbose=0)\n",
"\n",
"\n",
"@memory.cache\n",
"def classify_recipe_35(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": recipe,\n",
" },\n",
" ],\n",
" functions=[\n",
" {\n",
" \"name\": \"classify\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"has_non_fish_meat\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe contains any meat or meat products (eg. chicken broth) besides fish\",\n",
" },\n",
" \"requires_oven\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires an oven\",\n",
" },\n",
" \"requires_stove\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires a stove\",\n",
" },\n",
" \"cook_time_over_30_mins\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe takes over 30 minutes to prepare and cook, including waiting time\",\n",
" },\n",
" \"main_dish\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe can be served as a main dish\",\n",
" },\n",
" },\n",
" \"required\": [\n",
" \"has_non_fish_meat\",\n",
" \"requires_oven\",\n",
" \"requires_stove\",\n",
" \"cook_time_over_30_mins\",\n",
" \"main_dish\",\n",
" ],\n",
" },\n",
" }\n",
" ],\n",
" function_call={\n",
" \"name\": \"classify\",\n",
" },\n",
" )\n",
" return json.loads(completion.choices[0].message.function_call.arguments)\n",
"\n",
"\n",
"print(\"Classifying first recipe:\\n------------------\")\n",
"print(classify_recipe_35(recipes[0]))\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"100%|██████████| 500/500 [00:31<00:00, 15.77it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"test_data[\"gpt_3.5\"] = [classify_recipe_35(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 accuracy: 0.91\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 accuracy: {test_data['gpt_3.5_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And for completeness, let's try a fine-tuned GPT-3.5 model. You can find the fine-tuning code in [finetune-gpt-3.5.ipynb](./finetune-gpt-3.5.ipynb)\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'has_non_fish_meat': True,\n",
" 'requires_oven': False,\n",
" 'requires_stove': True,\n",
" 'cook_time_over_30_mins': False,\n",
" 'main_dish': False}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@memory.cache\n",
"def classify_recipe_35_ft(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"ft:gpt-3.5-turbo-0613:openpipe::7rZpPqYn\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several \"\n",
" \"dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\"role\": \"user\", \"content\": recipe},\n",
" ],\n",
" )\n",
"\n",
" return json.loads(completion.choices[0].message.content)\n",
"\n",
"\n",
"classify_recipe_35_ft(recipes[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 500/500 [07:31<00:00, 1.11it/s]\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft\"] = [classify_recipe_35_ft(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 FT accuracy: 0.94\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5_ft\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 FT accuracy: {test_data['gpt_3.5_ft_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look.\n"
]
},
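The inspection cell is collapsed in this diff; a minimal version of that filter might look like:

```python
# Assumed sketch: pull out rows where at least one field disagrees with GPT-4
mismatches = test_data[test_data["accuracy"] < 1]
print(f"{len(mismatches)}/{len(test_data)} rows have at least one mismatched label")
```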
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -625,9 +848,28 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A realistic point of comparison here might be GPT-3.5. Let's try to classify the same set of recipes using GPT-3.5 and see how it does. We'll use the same prompt that we used with GPT-4 to generate the initial training data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!\n"
]
},
{


@@ -6,104 +6,22 @@
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈.\n"
]
},
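The loading cell is collapsed below; presumably it reduces to something like this (the split name is an assumption):

```python
from datasets import load_dataset

# Assumed: pull the full ~2.1M-recipe dataset from the HuggingFace Hub
all_recipes = load_dataset("corbt/all-recipes", split="train")
print(f"{len(all_recipes):,} recipes loaded")
```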
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%%capture\n",
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -276,12 +194,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data.\n"
]
},
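As a quick sanity check on the table below, the per-recipe figure for the fine-tuned model is just the total cost divided by the dataset size:

```python
# Numbers taken from the table below
total_cost = 18.81  # USD for the full run on the fine-tuned Llama 2 7B
n_recipes = 2_147_248
print(f"${total_cost / n_recipes:.6f} per recipe")  # ~ $0.000009
```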
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -313,47 +231,47 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>Llama2 7B (FT)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" <td>18.81</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" <td>1033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>GPT-3.5 (FT)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" <td>8683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" <td>23190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
" Model Cost to Classify One Recipe \\\n",
"0 Llama2 7B (FT) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (FT) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
" Cost to Classify Entire Dataset \n",
"0 18.81 \n",
"1 1033.26 \n",
"2 8683.47 \n",
"3 23190.28 "
]
},
"execution_count": 19,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@@ -382,12 +300,12 @@
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"costs = pd.DataFrame(\n",
"models = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"Llama2 7B (FT)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-3.5 (FT)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
@@ -399,12 +317,11 @@
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"models[\"Cost to Classify Entire Dataset\"] = (\n",
" models[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").round(2)\n",
"\n",
"\n",
"costs\n"
"models\n"
]
}
],


@@ -1,10 +1,14 @@
# OpenPipe demo: fine-tuning your own model
# Tutorial: Fine-Tune your Own Llama 2
Hi there! This repository should give you a brief overview of how to fine-tune a competitive model from start to finish. You should review the notebooks in this directory in the following order:
Hi there! This directory should give you a brief overview of how to fine-tune a Llama 2 model from start to finish. The example model we're training will classify recipes from [a large dataset scraped from the internet](https://www.kaggle.com/datasets/wilmerarltstrmberg/recipe-dataset-over-2m). We'll use GPT-4 to generate labels for our training and test set, then fine-tune a Llama 2 model using the [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library. You should review the notebooks in this directory in the following order:
1. [./generate-data.ipynb](./generate-data.ipynb): Demonstrates how to generate a sample dataset of GPT-4 completions, store it using OpenPipe, and then export it in a format suitable for training a model.
2. [./train.ipynb](./train.ipynb): Trains a Llama 2 7B model on the dataset from step (1).
3. [./evaluate.ipynb](./evaluate.ipynb): Evaluates the model we trained using a special test set that we set aside in step (1).
4. [./benchmark.ipynb](./benchmark.ipynb): A script to compare costs and completion latencies between our fine-tuned model, GPT-3.5, and GPT-4.
1. [./1-generate-data.ipynb](./1-generate-data.ipynb): Demonstrates how to generate a sample dataset of GPT-4 completions, store it using OpenPipe, and then export it in a format suitable for training a model.
2. [./2-train.ipynb](./2-train.ipynb): Trains a Llama 2 7B model on the dataset from step (1).
3. [./3-evaluate.ipynb](./3-evaluate.ipynb): Evaluates the model we trained using a special test set that we set aside in step (1).
4. [./4-benchmark.ipynb](./4-benchmark.ipynb): A script to compare costs and completion latencies between our fine-tuned model, GPT-3.5, and GPT-4.
If you want to follow along yourself, I recommend using [RunPod](https://www.runpod.io/). The training scripts we use will run on any of their GPUs with 24GB of vRAM or more.
## About OpenPipe
[OpenPipe](https://openpipe.ai) is an open-source company that makes it easy for product engineers to build and deploy their own fine-tuned models. OpenPipe actually takes care of a lot of the steps this repository covers for you automatically, but we still wanted to give back and explain how fine-tuning works under the hood.


@@ -0,0 +1,207 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample training data:\n",
"{'messages': [{'content': 'Your goal is to classify a recipe along several '\n",
" 'dimensions.Pay attention to the instructions.',\n",
" 'role': 'system'},\n",
" {'content': 'Homemade Salad Dressing\\n'\n",
" '\\n'\n",
" 'Ingredients:\\n'\n",
" \"- 1 pt. Hellmann's mayonnaise\\n\"\n",
" '- 1 pt. buttermilk\\n'\n",
" '- 1 tsp. Accent\\n'\n",
" '- 2 Tbsp. dry parsley\\n'\n",
" '- 2 pkg. low-calorie Italian salad dressing mix\\n'\n",
" '- 1 can jalapeno peppers or 4 oz. Jimenez green '\n",
" 'sauce\\n'\n",
" '\\n'\n",
" 'Directions:\\n'\n",
" '- Blend well in blender; store in refrigerator.\\n'\n",
" '- For dip, decrease liquid.',\n",
" 'role': 'user'},\n",
" {'content': '{\\n'\n",
" '\"has_non_fish_meat\": false,\\n'\n",
" '\"requires_oven\": false,\\n'\n",
" '\"requires_stove\": false,\\n'\n",
" '\"cook_time_over_30_mins\": false,\\n'\n",
" '\"main_dish\": false\\n'\n",
" '}',\n",
" 'role': 'assistant'}]}\n"
]
}
],
"source": [
"import pandas as pd\n",
"from pprint import pprint\n",
"import json\n",
"\n",
"df = pd.read_json(\"data/train.jsonl\", lines=True)\n",
"\n",
"training_data = []\n",
"for row in df.itertuples():\n",
" input = json.loads(row.instruction)\n",
" output = json.loads(row.output)\n",
"\n",
" output[\"content\"] = output[\"function_call\"][\"arguments\"]\n",
" del output[\"function_call\"]\n",
"\n",
" sample = {\"messages\": input.copy() + [output]}\n",
" training_data.append(sample)\n",
"\n",
"# save the training data to data/train-gpt3.5.jsonl\n",
"\n",
"with open(\"data/train-gpt3.5.jsonl\", \"w\") as f:\n",
" for sample in training_data:\n",
" f.write(json.dumps(sample) + \"\\n\")\n",
"\n",
"print(f\"Sample training data:\")\n",
"pprint(training_data[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<File file id=file-faAdQ1KPxZH79ThW4Dbu4z1y at 0x7fa55db5c6d0> JSON: {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"uploaded\",\n",
" \"status_details\": null\n",
"}"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import openai\n",
"\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"openai.File.create(\n",
" file=open(\"data/train-gpt3.5.jsonl\", \"rb\"),\n",
" purpose=\"fine-tune\",\n",
" user_provided_filename=\"recipe-classification\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject list at 0x7fa55dbf6930> JSON: {\n",
" \"object\": \"list\",\n",
" \"data\": [\n",
" {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"processed\",\n",
" \"status_details\": null\n",
" }\n",
" ]\n",
"}"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.File.list()\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<FineTuningJob fine_tuning.job id=ftjob-EjjLxmj9P8apwPRk5s2NPSeB at 0x7fa55ddc4360> JSON: {\n",
" \"object\": \"fine_tuning.job\",\n",
" \"id\": \"ftjob-EjjLxmj9P8apwPRk5s2NPSeB\",\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"created_at\": 1693001190,\n",
" \"finished_at\": null,\n",
" \"fine_tuned_model\": null,\n",
" \"organization_id\": \"org-jRz4nVPMoeGHWL5nVR3Mb0kp\",\n",
" \"result_files\": [],\n",
" \"status\": \"created\",\n",
" \"validation_file\": null,\n",
" \"training_file\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"hyperparameters\": {\n",
" \"n_epochs\": 3\n",
" },\n",
" \"trained_tokens\": null\n",
"}"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.FineTuningJob.create(\n",
" training_file=\"file-faAdQ1KPxZH79ThW4Dbu4z1y\", model=\"gpt-3.5-turbo\"\n",
")\n"
]
}
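One step this notebook doesn't show is waiting for the job to finish. A minimal polling loop might look like the sketch below (the job id is copied from the output above; the terminal status names follow the OpenAI docs):

```python
import time

# Poll until the fine-tune job reaches a terminal state
job_id = "ftjob-EjjLxmj9P8apwPRk5s2NPSeB"
while True:
    job = openai.FineTuningJob.retrieve(job_id)
    if job.status in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(60)
print(job.status, job.fine_tuned_model)
```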
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

pnpm-lock.yaml (generated, 774 changes): diff suppressed because it is too large