Compare commits

..

7 Commits

Author  SHA1  Message  Date
David Corbitt  83d71c6e9d  Record updated version number  2023-08-31 22:12:29 -07:00
David Corbitt  1693ac1c58  Publish updated README to npm  2023-08-31 22:09:03 -07:00
David Corbitt  8de0c0fc5a  Update npm lib README  2023-08-31 21:57:30 -07:00
David Corbitt  3ed390c941  Close project menu after creating a project  2023-08-31 19:31:28 -07:00
David Corbitt  fa16dd61dc  Add publish script  2023-08-31 18:49:03 -07:00
David Corbitt  cb73598148  Update package version  2023-08-31 18:48:56 -07:00
David Corbitt  2f01e53cf3  Update README  2023-08-31 18:48:44 -07:00
102 changed files with 1171 additions and 4778 deletions

View File

@@ -40,8 +40,3 @@ SMTP_HOST="placeholder"
SMTP_PORT="placeholder"
SMTP_LOGIN="placeholder"
SMTP_PASSWORD="placeholder"
# Azure credentials are necessary for uploading large training data files
AZURE_STORAGE_ACCOUNT_NAME="placeholder"
AZURE_STORAGE_ACCOUNT_KEY="placeholder"
AZURE_STORAGE_CONTAINER_NAME="placeholder"

app/.gitignore vendored
View File

@@ -47,7 +47,3 @@ yarn-error.log*
# custom openai initialization
src/server/utils/openaiCustomConfig.json
# yalc
.yalc
yalc.lock

View File

@@ -19,8 +19,6 @@ declare module "nextjs-routes" {
| DynamicRoute<"/api/v1/[...trpc]", { "trpc": string[] }>
| StaticRoute<"/api/v1/openapi">
| StaticRoute<"/dashboard">
| DynamicRoute<"/datasets/[id]", { "id": string }>
| StaticRoute<"/datasets">
| DynamicRoute<"/experiments/[experimentSlug]", { "experimentSlug": string }>
| StaticRoute<"/experiments">
| StaticRoute<"/fine-tunes">

View File

@@ -26,8 +26,6 @@
"dependencies": {
"@anthropic-ai/sdk": "^0.5.8",
"@apidevtools/json-schema-ref-parser": "^10.1.0",
"@azure/identity": "^3.3.0",
"@azure/storage-blob": "12.15.0",
"@babel/standalone": "^7.22.9",
"@chakra-ui/anatomy": "^2.2.0",
"@chakra-ui/next-js": "^2.1.4",
@@ -71,7 +69,6 @@
"jsonschema": "^1.4.1",
"kysely": "^0.26.1",
"kysely-codegen": "^0.10.1",
"llama-tokenizer-js": "^1.1.3",
"lodash-es": "^4.17.21",
"lucide-react": "^0.265.0",
"marked": "^7.0.3",
@@ -82,7 +79,7 @@
"nextjs-routes": "^2.0.1",
"nodemailer": "^6.9.4",
"openai": "4.0.0-beta.7",
"openpipe": "0.4.0-beta.1",
"openpipe": "^0.3.0",
"openpipe-dev": "workspace:^",
"pg": "^8.11.2",
"pluralize": "^8.0.0",

View File

@@ -1,26 +0,0 @@
/*
Warnings:
- Added the required column `inputTokens` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty.
- Added the required column `outputTokens` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty.
- Added the required column `type` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty.
*/
-- CreateEnum
CREATE TYPE "DatasetEntryType" AS ENUM ('TRAIN', 'TEST');
-- AlterTable
ALTER TABLE "Dataset" ADD COLUMN "trainingRatio" DOUBLE PRECISION NOT NULL DEFAULT 0.8;
-- AlterTable
ALTER TABLE "DatasetEntry" ADD COLUMN "input" JSONB NOT NULL DEFAULT '[]',
ADD COLUMN "inputTokens" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN "output" JSONB,
ADD COLUMN "outputTokens" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN "type" "DatasetEntryType" NOT NULL DEFAULT 'TRAIN';
-- CreateIndex
CREATE INDEX "DatasetEntry_datasetId_createdAt_id_idx" ON "DatasetEntry"("datasetId", "createdAt", "id");
-- CreateIndex
CREATE INDEX "DatasetEntry_datasetId_type_idx" ON "DatasetEntry"("datasetId", "type");

View File

@@ -1,5 +0,0 @@
-- AlterTable
ALTER TABLE "DatasetEntry" ALTER COLUMN "loggedCallId" DROP NOT NULL,
ALTER COLUMN "inputTokens" DROP DEFAULT,
ALTER COLUMN "outputTokens" DROP DEFAULT,
ALTER COLUMN "type" DROP DEFAULT;

View File

@@ -1,23 +0,0 @@
-- CreateEnum
CREATE TYPE "DatasetFileUploadStatus" AS ENUM ('PENDING', 'DOWNLOADING', 'PROCESSING', 'SAVING', 'COMPLETE', 'ERROR');
-- CreateTable
CREATE TABLE "DatasetFileUpload" (
"id" UUID NOT NULL,
"datasetId" UUID NOT NULL,
"blobName" TEXT NOT NULL,
"fileName" TEXT NOT NULL,
"fileSize" INTEGER NOT NULL,
"progress" INTEGER NOT NULL DEFAULT 0,
"status" "DatasetFileUploadStatus" NOT NULL DEFAULT 'PENDING',
"uploadedAt" TIMESTAMP(3) NOT NULL,
"visible" BOOLEAN NOT NULL DEFAULT true,
"errorMessage" TEXT,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "DatasetFileUpload_pkey" PRIMARY KEY ("id")
);
-- AddForeignKey
ALTER TABLE "DatasetFileUpload" ADD CONSTRAINT "DatasetFileUpload_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -176,42 +176,12 @@ model OutputEvaluation {
@@unique([modelResponseId, evaluationId])
}
enum DatasetFileUploadStatus {
PENDING
DOWNLOADING
PROCESSING
SAVING
COMPLETE
ERROR
}
model DatasetFileUpload {
id String @id @default(uuid()) @db.Uuid
datasetId String @db.Uuid
dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
blobName String
fileName String
fileSize Int
progress Int @default(0) // Percentage
status DatasetFileUploadStatus @default(PENDING)
uploadedAt DateTime
visible Boolean @default(true)
errorMessage String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
}
model Dataset {
id String @id @default(uuid()) @db.Uuid
name String
datasetEntries DatasetEntry[]
fineTunes FineTune[]
datasetFileUploads DatasetFileUpload[]
trainingRatio Float @default(0.8)
name String
datasetEntries DatasetEntry[]
fineTunes FineTune[]
projectId String @db.Uuid
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
@@ -220,32 +190,17 @@ model Dataset {
updatedAt DateTime @updatedAt
}
enum DatasetEntryType {
TRAIN
TEST
}
model DatasetEntry {
id String @id @default(uuid()) @db.Uuid
loggedCallId String? @db.Uuid
loggedCall LoggedCall? @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
input Json @default("[]")
output Json?
inputTokens Int
outputTokens Int
type DatasetEntryType
loggedCallId String @db.Uuid
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
datasetId String @db.Uuid
dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@index([datasetId, createdAt, id])
@@index([datasetId, type])
}
model Project {
@@ -497,7 +452,7 @@ model FineTune {
deploymentFinishedAt DateTime?
datasetId String @db.Uuid
dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
projectId String @db.Uuid
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
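One side of this schema diff keeps the input/output/token fields on DatasetEntry. Assuming that variant, creating a dataset with a seed entry through the Prisma client would look roughly like this (all values are placeholders):

const dataset = await prisma.dataset.create({
  data: {
    name: "support-chats", // placeholder name
    projectId, // an existing Project id
    trainingRatio: 0.8,
    datasetEntries: {
      create: [
        {
          input: [{ role: "user", content: "Hi" }], // stored as Json
          inputTokens: 3,
          outputTokens: 0,
          type: "TRAIN",
        },
      ],
    },
  },
});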

View File

@@ -1,4 +1,5 @@
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import dedent from "dedent";
import { execSync } from "child_process";
import fs from "fs";

View File

@@ -108,7 +108,7 @@ const MODEL_RESPONSE_TEMPLATES: {
inputTokens: 236,
outputTokens: 5,
finishReason: "stop",
tags: [{ name: "prompt_id", value: "add_scenario" }],
tags: [{ name: "prompt_id", value: "define_func" }],
},
{
reqPayload: {
@@ -311,7 +311,7 @@ const MODEL_RESPONSE_TEMPLATES: {
outputTokens: 108,
finishReason: "stop",
tags: [
{ name: "prompt_id", value: "define_func" },
{ name: "prompt_id", value: "chatcmpl-7" },
{ name: "some_other_tag", value: "some_other_value" },
],
},

View File

@@ -2,15 +2,17 @@ import { Button, Icon, useDisclosure, Text } from "@chakra-ui/react";
import { useRouter } from "next/router";
import { BsTrash } from "react-icons/bs";
import { useAppStore } from "~/state/store";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import DeleteExperimentDialog from "../DeleteExperimentDialog";
import DeleteExperimentDialog from "../experiments/DeleteExperimentDialog";
export const DeleteButton = ({ closeDrawer }: { closeDrawer: () => void }) => {
export const DeleteButton = () => {
const experiment = useExperiment();
const router = useRouter();
const disclosure = useDisclosure();
const closeDrawer = useAppStore((s) => s.closeDrawer);
const [onDelete] = useHandledAsyncCallback(async () => {
await router.push({ pathname: "/experiments" });
closeDrawer();

View File

@@ -7,19 +7,18 @@ import {
DrawerOverlay,
Heading,
VStack,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import EditScenarioVars from "./EditScenarioVars";
import EditEvaluations from "./EditEvaluations";
import EditScenarioVars from "../OutputsTable/EditScenarioVars";
import EditEvaluations from "../OutputsTable/EditEvaluations";
import { useAppStore } from "~/state/store";
import { DeleteButton } from "./DeleteButton";
export default function ExperimentSettingsDrawer({
disclosure,
}: {
disclosure: UseDisclosureReturn;
}) {
export default function ExperimentSettingsDrawer() {
const isOpen = useAppStore((state) => state.drawerOpen);
const closeDrawer = useAppStore((state) => state.closeDrawer);
return (
<Drawer placement="right" size="md" {...disclosure}>
<Drawer isOpen={isOpen} placement="right" onClose={closeDrawer} size="md">
<DrawerOverlay />
<DrawerContent>
<DrawerCloseButton />
@@ -32,7 +31,7 @@ export default function ExperimentSettingsDrawer({
<EditScenarioVars />
<EditEvaluations />
</VStack>
<DeleteButton closeDrawer={disclosure.onClose} />
<DeleteButton />
</VStack>
</DrawerBody>
</DrawerContent>

View File

@@ -16,16 +16,12 @@ import {
import { FiChevronDown } from "react-icons/fi";
import { BiCheck } from "react-icons/bi";
import { isEqual } from "lodash-es";
import React from "react";
type InputDropdownProps<T> = {
options: ReadonlyArray<T>;
selectedOption: T;
onSelect: (option: T) => void;
inputGroupProps?: InputGroupProps;
getDisplayLabel?: (option: T) => string;
isDisabled?: boolean;
};
const InputDropdown = <T,>({
@@ -33,21 +29,19 @@ const InputDropdown = <T,>({
selectedOption,
onSelect,
inputGroupProps,
getDisplayLabel = (option) => option as string,
isDisabled,
}: InputDropdownProps<T>) => {
const { onOpen, ...popover } = useDisclosure();
const popover = useDisclosure();
return (
<Popover placement="bottom-start" onOpen={isDisabled ? undefined : onOpen} {...popover}>
<Popover placement="bottom-start" {...popover}>
<PopoverTrigger>
<InputGroup
cursor="pointer"
w={getDisplayLabel(selectedOption).length * 14 + 180}
w={(selectedOption as string).length * 14 + 180}
{...inputGroupProps}
>
<Input
value={getDisplayLabel(selectedOption)}
value={selectedOption as string}
// eslint-disable-next-line @typescript-eslint/no-empty-function -- controlled input requires onChange
onChange={() => {}}
cursor="pointer"
@@ -58,10 +52,9 @@ const InputDropdown = <T,>({
onFocus={(e) => {
e.target.blur();
}}
isDisabled={isDisabled}
/>
<InputRightElement>
<Icon as={FiChevronDown} color={isDisabled ? "gray.300" : undefined} />
<Icon as={FiChevronDown} />
</InputRightElement>
</InputGroup>
</PopoverTrigger>
@@ -85,10 +78,8 @@ const InputDropdown = <T,>({
fontSize="sm"
borderBottomWidth={1}
>
<Text mr={16}>{getDisplayLabel(option)}</Text>
{isEqual(option, selectedOption) && (
<Icon as={BiCheck} color="blue.500" boxSize={5} />
)}
<Text mr={16}>{option as string}</Text>
{option === selectedOption && <Icon as={BiCheck} color="blue.500" boxSize={5} />}
</HStack>
))}
</VStack>

View File

@@ -19,7 +19,7 @@ import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import AutoResizeTextArea from "~/components/AutoResizeTextArea";
import AutoResizeTextArea from "../AutoResizeTextArea";
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;

View File

@@ -5,7 +5,7 @@ import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback, useScenarioVars } from "~/utils/hooks";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import { FloatingLabelInput } from "~/components/OutputsTable/FloatingLabelInput";
import { FloatingLabelInput } from "./FloatingLabelInput";
export const ScenarioVar = ({
variable,

View File

@@ -19,13 +19,15 @@ import {
useScenarios,
} from "~/utils/hooks";
import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
export const ActionButton = (props: ButtonProps) => (
<Button size="sm" variant="ghost" color="gray.600" {...props} />
);
export const ScenariosHeader = ({ openDrawer }: { openDrawer: () => void }) => {
export const ScenariosHeader = () => {
const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const scenarios = useScenarios();

View File

@@ -20,7 +20,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
inputTokens: 0,
outputTokens: 0,
scenarioCount: 0,
finishedCount: 0,
outputCount: 0,
awaitingCompletions: false,
awaitingEvals: false,
},
@@ -42,7 +42,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.finishedCount;
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
return (
<HStack
@@ -55,7 +55,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
<HStack px={cellPadding.x} flexWrap="wrap">
{showNumFinished && (
<Text>
{data.finishedCount} / {data.scenarioCount}
{data.outputCount} / {data.scenarioCount}
</Text>
)}
{data.evalResults.map((result) => {

View File

@@ -12,13 +12,7 @@ import ScenarioPaginator from "./ScenarioPaginator";
import { Fragment } from "react";
import useScrolledPast from "./useHasScrolledPast";
export default function OutputsTable({
experimentId,
openDrawer,
}: {
experimentId: string | undefined;
openDrawer: () => void;
}) {
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery(
{ experimentId: experimentId as string },
{ enabled: !!experimentId },
@@ -97,7 +91,7 @@ export default function OutputsTable({
colStart={1}
borderRightWidth={0}
>
<ScenariosHeader openDrawer={openDrawer} />
<ScenariosHeader />
</GridItem>
{scenarios.data.scenarios.map((scenario, i) => (

View File

@@ -1,37 +0,0 @@
import {
Drawer,
DrawerBody,
DrawerCloseButton,
DrawerContent,
DrawerHeader,
DrawerOverlay,
Heading,
VStack,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { DeleteButton } from "./DeleteButton";
export default function DatasetConfigurationDrawer({
disclosure,
}: {
disclosure: UseDisclosureReturn;
}) {
return (
<Drawer placement="right" size="md" {...disclosure}>
<DrawerOverlay />
<DrawerContent>
<DrawerCloseButton />
<DrawerHeader>
<Heading size="md">Dataset Configuration</Heading>
</DrawerHeader>
<DrawerBody h="full" pb={4}>
<VStack h="full" justifyContent="space-between">
<VStack spacing={6}></VStack>
<DeleteButton closeDrawer={disclosure.onClose} />
</VStack>
</DrawerBody>
</DrawerContent>
</Drawer>
);
}

View File

@@ -1,39 +0,0 @@
import { Button, Icon, useDisclosure, Text } from "@chakra-ui/react";
import { useRouter } from "next/router";
import { BsTrash } from "react-icons/bs";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import DeleteDatasetDialog from "./DeleteDatasetDialog";
export const DeleteButton = ({ closeDrawer }: { closeDrawer: () => void }) => {
const dataset = useDataset();
const router = useRouter();
const disclosure = useDisclosure();
const [onDelete] = useHandledAsyncCallback(async () => {
await router.push({ pathname: "/datasets" });
closeDrawer();
}, [router, closeDrawer]);
return (
<>
<Button
size="sm"
variant="ghost"
colorScheme="red"
fontWeight="normal"
onClick={disclosure.onOpen}
>
<Icon as={BsTrash} boxSize={4} />
<Text ml={2}>Delete Dataset</Text>
</Button>
<DeleteDatasetDialog
datasetId={dataset.data?.id}
onDelete={onDelete}
disclosure={disclosure}
/>
</>
);
};

View File

@@ -1,73 +0,0 @@
import { useRef } from "react";
import {
type UseDisclosureReturn,
AlertDialog,
AlertDialogOverlay,
AlertDialogContent,
AlertDialogHeader,
AlertDialogBody,
AlertDialogFooter,
Button,
} from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
const DeleteDatasetDialog = ({
datasetId,
onDelete,
disclosure,
}: {
datasetId?: string;
onDelete?: () => void;
disclosure: UseDisclosureReturn;
}) => {
const cancelRef = useRef<HTMLButtonElement>(null);
const mutation = api.datasets.delete.useMutation();
const utils = api.useContext();
const [onDeleteConfirm, deletionInProgress] = useHandledAsyncCallback(async () => {
if (!datasetId) return;
await mutation.mutateAsync({ id: datasetId });
await utils.datasets.list.invalidate();
onDelete?.();
disclosure.onClose();
}, [mutation, datasetId, disclosure.onClose]);
console.log("dataset id", datasetId);
return (
<AlertDialog leastDestructiveRef={cancelRef} {...disclosure}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Dataset
</AlertDialogHeader>
<AlertDialogBody>
If you delete this dataset all the associated dataset entries will be deleted as well.
Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={disclosure.onClose}>
Cancel
</Button>
<Button
colorScheme="red"
isLoading={deletionInProgress}
onClick={onDeleteConfirm}
ml={3}
>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default DeleteDatasetDialog;

View File

@@ -1,46 +0,0 @@
import { Card, Table, Tbody } from "@chakra-ui/react";
import { useState } from "react";
import { useDatasetEntries } from "~/utils/hooks";
import { TableHeader, TableRow, EmptyTableRow } from "./TableRow";
import DatasetEntryEditorDrawer from "./DatasetEntryEditorDrawer";
export default function DatasetEntriesTable() {
const [expandedDatasetEntryId, setExpandedDatasetEntryId] = useState<string | null>(null);
const datasetEntries = useDatasetEntries().data?.entries;
return (
<>
<Card width="100%" overflowX="auto">
<Table>
<TableHeader />
<Tbody>
{datasetEntries?.length ? (
datasetEntries?.map((entry) => {
return (
<TableRow
key={entry.id}
datasetEntry={entry}
onToggle={() => {
if (entry.id === expandedDatasetEntryId) {
setExpandedDatasetEntryId(null);
} else {
setExpandedDatasetEntryId(entry.id);
}
}}
showOptions
/>
);
})
) : (
<EmptyTableRow />
)}
</Tbody>
</Table>
</Card>
<DatasetEntryEditorDrawer
datasetEntryId={expandedDatasetEntryId}
clearDatasetEntryId={() => setExpandedDatasetEntryId(null)}
/>
</>
);
}

View File

@@ -1,174 +0,0 @@
import { useState, useEffect, useMemo } from "react";
import {
Drawer,
DrawerBody,
DrawerCloseButton,
DrawerContent,
DrawerHeader,
DrawerOverlay,
DrawerFooter,
Heading,
VStack,
HStack,
Button,
Text,
Divider,
Icon,
} from "@chakra-ui/react";
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
import { BsPlus } from "react-icons/bs";
import { type DatasetEntryType } from "@prisma/client";
import { api } from "~/utils/api";
import { useDatasetEntry, useHandledAsyncCallback } from "~/utils/hooks";
import EditableMessage from "./EditableMessage";
import EntryTypeDropdown from "./EntryTypeDropdown";
export default function DatasetEntryEditorDrawer({
datasetEntryId,
clearDatasetEntryId,
}: {
datasetEntryId: string | null;
clearDatasetEntryId: () => void;
}) {
const utils = api.useContext();
const datasetEntry = useDatasetEntry(datasetEntryId).data;
const savedInputMessages = useMemo(
() => datasetEntry?.input as unknown as CreateChatCompletionRequestMessage[],
[datasetEntry],
);
const savedOutputMessage = useMemo(
() => datasetEntry?.output as unknown as CreateChatCompletionRequestMessage,
[datasetEntry],
);
const [inputMessagesToSave, setInputMessagesToSave] = useState<
CreateChatCompletionRequestMessage[]
>([]);
const [outputMessageToSave, setOutputMessageToSave] =
useState<CreateChatCompletionRequestMessage | null>(null);
useEffect(() => {
if (savedInputMessages) {
setInputMessagesToSave(savedInputMessages);
setOutputMessageToSave(savedOutputMessage);
}
}, [savedInputMessages, savedOutputMessage]);
const updateMutation = api.datasetEntries.update.useMutation();
const [onSave, savingInProgress] = useHandledAsyncCallback(async () => {
if (!datasetEntryId || !inputMessagesToSave) return;
await updateMutation.mutateAsync({
id: datasetEntryId,
updates: {
input: JSON.stringify(inputMessagesToSave),
output: JSON.stringify(outputMessageToSave),
},
});
await utils.datasetEntries.list.invalidate();
await utils.datasetEntries.get.invalidate({ id: datasetEntryId });
}, [updateMutation, datasetEntryId, inputMessagesToSave, outputMessageToSave, utils]);
const [onUpdateType] = useHandledAsyncCallback(
async (type: DatasetEntryType) => {
if (!datasetEntryId) return;
await updateMutation.mutateAsync({
id: datasetEntryId,
updates: {
type,
},
});
await utils.datasetEntries.list.invalidate();
await utils.datasetEntries.get.invalidate({ id: datasetEntryId });
},
[updateMutation, datasetEntryId, utils],
);
return (
<Drawer isOpen={!!datasetEntryId} onClose={clearDatasetEntryId} placement="right" size="md">
<DrawerOverlay />
<DrawerContent>
<DrawerCloseButton pt={6} />
<DrawerHeader bgColor="orange.50">
<HStack w="full" justifyContent="space-between" pr={8}>
<Heading size="md">Dataset Entry</Heading>
{datasetEntry && (
<EntryTypeDropdown type={datasetEntry.type} onTypeChange={onUpdateType} />
)}
</HStack>
</DrawerHeader>
<DrawerBody h="full" pb={4} bgColor="orange.50">
<VStack h="full" justifyContent="space-between">
<VStack w="full" spacing={12} py={4}>
<VStack w="full" alignItems="flex-start">
<Text fontWeight="bold">Input</Text>
{inputMessagesToSave.map((message, i) => {
return (
<>
<Divider key={`divider-${i}`} my={4} />
<EditableMessage
key={i}
message={message}
onEdit={(message) => {
const newInputMessages = [...inputMessagesToSave];
newInputMessages[i] = message;
setInputMessagesToSave(newInputMessages);
}}
onDelete={() => {
const newInputMessages = [...inputMessagesToSave];
newInputMessages.splice(i, 1);
setInputMessagesToSave(newInputMessages);
}}
/>
</>
);
})}
<Divider my={4} />
<Button
w="full"
onClick={() =>
setInputMessagesToSave([...inputMessagesToSave, { role: "user", content: "" }])
}
variant="outline"
color="gray.500"
_hover={{ bgColor: "orange.100" }}
>
<HStack spacing={0}>
<Text>Add Message</Text>
<Icon as={BsPlus} boxSize={6} />
</HStack>
</Button>
</VStack>
<VStack w="full" alignItems="flex-start">
<Text fontWeight="bold">Output</Text>
<Divider my={4} />
<EditableMessage
message={outputMessageToSave}
onEdit={(message) => setOutputMessageToSave(message)}
isOutput
/>
</VStack>
</VStack>
</VStack>
</DrawerBody>
<DrawerFooter bgColor="orange.50">
<HStack>
<Button
onClick={() => {
setInputMessagesToSave(savedInputMessages);
setOutputMessageToSave(savedOutputMessage);
}}
>
Reset
</Button>
<Button isLoading={savingInProgress} onClick={onSave} colorScheme="orange">
Save
</Button>
</HStack>
</DrawerFooter>
</DrawerContent>
</Drawer>
);
}

View File

@@ -1,105 +0,0 @@
import { VStack, HStack, Tooltip, IconButton, Icon } from "@chakra-ui/react";
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
import { BsX } from "react-icons/bs";
import AutoResizeTextArea from "~/components/AutoResizeTextArea";
import InputDropdown from "~/components/InputDropdown";
import { parseableToFunctionCall } from "~/utils/utils";
import FunctionCallEditor from "./FunctionCallEditor";
const MESSAGE_ROLE_OPTIONS = ["system", "user", "assistant", "function"] as const;
const OUTPUT_OPTIONS = ["plaintext", "func_call"] as const;
const EditableMessage = ({
message,
onEdit,
onDelete,
isOutput,
}: {
message: CreateChatCompletionRequestMessage | null;
onEdit: (message: CreateChatCompletionRequestMessage) => void;
onDelete?: () => void;
isOutput?: boolean;
}) => {
const { role = "assistant", content = "", function_call } = message || {};
const currentOutputOption: (typeof OUTPUT_OPTIONS)[number] = function_call
? "func_call"
: "plaintext";
return (
<VStack w="full">
<HStack w="full" justifyContent="space-between">
<HStack>
{!isOutput && (
<InputDropdown
options={MESSAGE_ROLE_OPTIONS}
selectedOption={role}
onSelect={(option) => {
const updatedMessage = { role: option, content };
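// Switching away from an assistant function call: preserve the call by stringifying it into plain content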
if (role === "assistant" && currentOutputOption === "func_call") {
updatedMessage.content = JSON.stringify(function_call, null, 2);
}
onEdit(updatedMessage);
}}
inputGroupProps={{ w: "32", bgColor: "white" }}
/>
)}
{role === "assistant" && (
<InputDropdown
options={OUTPUT_OPTIONS}
selectedOption={currentOutputOption}
onSelect={(option) => {
const updatedMessage: CreateChatCompletionRequestMessage = {
role,
content: null,
function_call: undefined,
};
if (option === "plaintext") {
updatedMessage.content = JSON.stringify(function_call, null, 2);
} else if (option === "func_call") {
updatedMessage.function_call =
content && parseableToFunctionCall(content)
? JSON.parse(content)
: { name: "", arguments: "{}" };
}
onEdit(updatedMessage);
}}
inputGroupProps={{ w: "32", bgColor: "white" }}
/>
)}
</HStack>
{!isOutput && (
<HStack>
<Tooltip label="Delete" hasArrow>
<IconButton
aria-label="Delete"
icon={<Icon as={BsX} boxSize={6} />}
onClick={onDelete}
size="xs"
display="flex"
colorScheme="gray"
color="gray.500"
variant="ghost"
/>
</Tooltip>
</HStack>
)}
</HStack>
{function_call ? (
<FunctionCallEditor
function_call={function_call}
onEdit={(function_call) => onEdit({ role, function_call, content: null })}
/>
) : (
<AutoResizeTextArea
value={content || JSON.stringify(function_call, null, 2)}
onChange={(e) => onEdit({ role, content: e.target.value })}
bgColor="white"
/>
)}
</VStack>
);
};
export default EditableMessage;

View File

@@ -1,24 +0,0 @@
import { type DatasetEntryType } from "@prisma/client";
import InputDropdown from "~/components/InputDropdown";
const ENTRY_TYPE_OPTIONS: DatasetEntryType[] = ["TRAIN", "TEST"];
const EntryTypeDropdown = ({
type,
onTypeChange,
}: {
type: DatasetEntryType;
onTypeChange: (type: DatasetEntryType) => void;
}) => {
return (
<InputDropdown
options={ENTRY_TYPE_OPTIONS}
selectedOption={type}
onSelect={onTypeChange}
inputGroupProps={{ w: "32", bgColor: "white" }}
/>
);
};
export default EntryTypeDropdown;

View File

@@ -1,125 +0,0 @@
import { useRef, useMemo, useEffect } from "react";
import { VStack, HStack, Text, Input, Box } from "@chakra-ui/react";
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
import { useAppStore } from "~/state/store";
import { type CreatedEditor } from "~/state/sharedVariantEditor.slice";
const FunctionCallEditor = ({
function_call,
onEdit,
}: {
function_call: CreateChatCompletionRequestMessage.FunctionCall;
onEdit: (function_call: CreateChatCompletionRequestMessage.FunctionCall) => void;
}) => {
const monaco = useAppStore.use.sharedArgumentsEditor.monaco();
const editorRef = useRef<CreatedEditor | null>(null);
const editorId = useMemo(() => `editor_${Math.random().toString(36).substring(7)}`, []);
useEffect(() => {
if (monaco) {
const container = document.getElementById(editorId) as HTMLElement;
const editor = monaco.editor.create(container, {
value: function_call.arguments,
language: "json",
theme: "customTheme",
lineNumbers: "off",
minimap: { enabled: false },
wrappingIndent: "indent",
wrappingStrategy: "advanced",
wordWrap: "on",
folding: false,
scrollbar: {
alwaysConsumeMouseWheel: false,
verticalScrollbarSize: 0,
},
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
quickSuggestions: true,
renderLineHighlight: "none",
fontSize: 14,
scrollBeyondLastLine: false,
});
editorRef.current = editor;
const updateHeight = () => {
const contentHeight = editor.getContentHeight();
container.style.height = `${contentHeight}px`;
editor.layout();
};
const attemptDocumentFormat = () => {
const action = editor.getAction("editor.action.formatDocument");
if (action) {
action
.run()
.then(updateHeight)
.catch((error) => {
console.error("Error running formatDocument:", error);
});
return true;
}
return false;
};
editor.onDidBlurEditorText(() => {
attemptDocumentFormat();
onEdit({ name: function_call.name, arguments: editor.getValue() });
});
// Interval function to check for action availability
const checkForActionInterval = setInterval(() => {
const formatted = attemptDocumentFormat();
if (formatted) {
clearInterval(checkForActionInterval); // Clear the interval once the action is found and run
}
}, 100); // Check every 100ms
// Add content change listener
const contentChangeListener = editor.onDidChangeModelContent(updateHeight);
const resizeObserver = new ResizeObserver(() => {
editor.layout();
});
resizeObserver.observe(container);
return () => {
contentChangeListener.dispose();
resizeObserver.disconnect();
editor?.dispose();
};
}
}, [monaco, editorId, function_call.name, function_call.arguments, onEdit]);
return (
<VStack w="full" alignItems="flex-start">
<HStack w="full">
<Text fontWeight="bold" w={192}>
Name:
</Text>
<Input
value={function_call.name}
onChange={(e) => onEdit({ name: e.target.value, arguments: function_call.arguments })}
bgColor="white"
/>
</HStack>
<Text fontWeight="bold" w={32}>
Arguments
</Text>
<VStack
borderRadius={4}
border="1px solid"
borderColor="gray.200"
w="full"
py={1}
bgColor="white"
>
<Box id={editorId} w="full" />
</VStack>
</VStack>
);
};
export default FunctionCallEditor;

View File

@@ -1,128 +0,0 @@
import { Box, Td, Tr, Thead, Th, Tooltip, HStack, Text, Checkbox } from "@chakra-ui/react";
import Link from "next/link";
import dayjs from "~/utils/dayjs";
import { type RouterOutputs } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { useIsClientRehydrated, useDatasetEntries } from "~/utils/hooks";
import { useMemo } from "react";
type DatasetEntry = RouterOutputs["datasetEntries"]["list"]["entries"][0];
export const TableHeader = () => {
const matchingDatasetEntryIds = useDatasetEntries().data?.matchingEntryIds;
const selectedDatasetEntryIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const addSelectedIds = useAppStore((s) => s.selectedDatasetEntries.addSelectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const allSelected = useMemo(() => {
if (!matchingDatasetEntryIds || !matchingDatasetEntryIds.length) return false;
return matchingDatasetEntryIds.every((id) => selectedDatasetEntryIds.has(id));
}, [matchingDatasetEntryIds, selectedDatasetEntryIds]);
const isClientRehydrated = useIsClientRehydrated();
if (!isClientRehydrated) return null;
return (
<Thead>
<Tr>
<Th pr={0}>
<HStack minW={16}>
<Checkbox
isChecked={allSelected}
onChange={() => {
allSelected ? clearSelectedIds() : addSelectedIds(matchingDatasetEntryIds || []);
}}
/>
<Text>
({selectedDatasetEntryIds.size ? `${selectedDatasetEntryIds.size}/` : ""}
{matchingDatasetEntryIds?.length || 0})
</Text>
</HStack>
</Th>
<Th>Created At</Th>
<Th isNumeric>Input tokens</Th>
<Th isNumeric>Output tokens</Th>
<Th isNumeric>Type</Th>
</Tr>
</Thead>
);
};
export const TableRow = ({
datasetEntry,
onToggle,
showOptions,
}: {
datasetEntry: DatasetEntry;
onToggle: () => void;
showOptions?: boolean;
}) => {
const createdAt = dayjs(datasetEntry.createdAt).format("MMMM D h:mm A");
const fullTime = dayjs(datasetEntry.createdAt).toString();
const isChecked = useAppStore((s) => s.selectedDatasetEntries.selectedIds.has(datasetEntry.id));
const toggleChecked = useAppStore((s) => s.selectedDatasetEntries.toggleSelectedId);
const isClientRehydrated = useIsClientRehydrated();
if (!isClientRehydrated) return null;
return (
<Tr
onClick={onToggle}
key={datasetEntry.id}
_hover={{ bgColor: "gray.50", cursor: "pointer" }}
fontSize="sm"
>
{showOptions && (
<Td>
<Checkbox isChecked={isChecked} onChange={() => toggleChecked(datasetEntry.id)} />
</Td>
)}
<Td>
<Tooltip label={fullTime} placement="top">
<Box whiteSpace="nowrap" minW="120px">
{createdAt}
</Box>
</Tooltip>
</Td>
<Td isNumeric>{datasetEntry.inputTokens}</Td>
<Td isNumeric>{datasetEntry.outputTokens}</Td>
<Td isNumeric>{datasetEntry.type}</Td>
</Tr>
);
};
export const EmptyTableRow = ({ filtersApplied = true }: { filtersApplied?: boolean }) => {
const visibleColumns = useAppStore((s) => s.columnVisibility.visibleColumns);
const filters = useAppStore((state) => state.logFilters.filters);
const { isLoading } = useDatasetEntries();
if (isLoading) return null;
if (filters.length && filtersApplied) {
return (
<Tr>
<Td w="full" colSpan={visibleColumns.size + 1}>
<Text color="gray.500" textAlign="center" w="full" p={4}>
No matching entries found. Try removing some filters.
</Text>
</Td>
</Tr>
);
}
return (
<Tr>
<Td w="full" colSpan={visibleColumns.size + 1}>
<Text color="gray.500" textAlign="center" w="full" p={4}>
This dataset has no entries. Add some logs in the{" "}
<Link href="/request-logs">
<Text as="span" color="blue.600">
Request Logs
</Text>
</Link>{" "}
tab.
</Text>
</Td>
</Tr>
);
};

View File

@@ -1,16 +0,0 @@
import { type StackProps } from "@chakra-ui/react";
import { useDatasetEntries } from "~/utils/hooks";
import Paginator from "../Paginator";
const DatasetEntryPaginator = (props: StackProps) => {
const { data } = useDatasetEntries();
if (!data) return null;
const { matchingEntryIds } = data;
return <Paginator count={matchingEntryIds.length} {...props} />;
};
export default DatasetEntryPaginator;

View File

@@ -1,20 +0,0 @@
import { Button, HStack, Icon, Text } from "@chakra-ui/react";
import { useDataset } from "~/utils/hooks";
import { BsGearFill } from "react-icons/bs";
export const DatasetHeaderButtons = ({ openDrawer }: { openDrawer: () => void }) => {
const dataset = useDataset();
if (dataset.isLoading) return null;
return (
<HStack spacing={0} mt={{ base: 2, md: 0 }}>
<Button variant={{ base: "solid", md: "ghost" }} onClick={openDrawer}>
<HStack>
<Icon as={BsGearFill} />
<Text>Configure</Text>
</HStack>
</Button>
</HStack>
);
};

View File

@@ -1,52 +0,0 @@
import { Card, Table, Thead, Tr, Th, Tbody, Td, VStack, Icon, Text } from "@chakra-ui/react";
import { FaTable } from "react-icons/fa";
import Link from "next/link";
import dayjs from "~/utils/dayjs";
import { useDatasets } from "~/utils/hooks";
const DatasetsTable = ({}) => {
const { data } = useDatasets();
const datasets = data || [];
return (
<Card width="100%" overflowX="auto">
{datasets.length ? (
<Table>
<Thead>
<Tr>
<Th>Name</Th>
<Th>Created At</Th>
<Th>Size</Th>
</Tr>
</Thead>
<Tbody>
{datasets.map((dataset) => {
return (
<Tr key={dataset.id}>
<Td>
<Link href={{ pathname: "/datasets/[id]", query: { id: dataset.id } }}>
<Text color="blue.600">{dataset.name}</Text>
</Link>
</Td>
<Td>{dayjs(dataset.createdAt).format("MMMM D h:mm A")}</Td>
<Td>{dataset._count.datasetEntries}</Td>
</Tr>
);
})}
</Tbody>
</Table>
) : (
<VStack py={8}>
<Icon as={FaTable} boxSize={16} color="gray.300" />
<Text color="gray.400" fontSize="lg" fontWeight="bold">
No Datasets Found. Create your first dataset.
</Text>
</VStack>
)}
</Card>
);
};
export default DatasetsTable;

View File

@@ -1,107 +0,0 @@
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { BsTrash } from "react-icons/bs";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import pluralize from "pluralize";
const DeleteButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Delete"
icon={BsTrash}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<DeleteDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default DeleteButton;
const DeleteDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const deleteRowsMutation = api.datasetEntries.delete.useMutation();
const utils = api.useContext();
const [deleteRows, deletionInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size) return;
// divide selectedIds into chunks of 15000 to reduce request size
const chunkSize = 15000;
const idsArray = Array.from(selectedIds);
for (let i = 0; i < idsArray.length; i += chunkSize) {
const response = await deleteRowsMutation.mutateAsync({
ids: idsArray.slice(i, i + chunkSize),
});
if (maybeReportError(response)) return;
}
await utils.datasetEntries.list.invalidate();
disclosure.onClose();
clearSelectedIds();
}, [deleteRowsMutation, dataset, selectedIds, utils]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={BsTrash} />
<Text>Delete Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
Are you sure you want to delete the <b>{selectedIds.size}</b>{" "}
{pluralize("row", selectedIds.size)} rows you've selected?
</Text>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="red" onClick={deleteRows} isLoading={deletionInProgress} minW={24}>
Delete
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -1,182 +0,0 @@
import { useState, useEffect } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Checkbox,
NumberInput,
NumberInputField,
NumberInputStepper,
NumberIncrementStepper,
NumberDecrementStepper,
Collapse,
Flex,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { AiOutlineDownload } from "react-icons/ai";
import { useHandledAsyncCallback, useDataset } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
import InfoCircle from "../InfoCircle";
const ExportButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Download"
icon={AiOutlineDownload}
isDisabled={selectedIds.size === 0}
requireBeta
/>
<ExportDatasetEntriesModal disclosure={disclosure} />
</>
);
};
export default ExportButton;
const ExportDatasetEntriesModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds);
const [testingSplit, setTestingSplit] = useState(10);
const [removeDuplicates, setRemoveDuplicates] = useState(false);
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
useEffect(() => {
if (disclosure.isOpen) {
setTestingSplit(10);
setRemoveDuplicates(false);
}
}, [disclosure.isOpen]);
const exportDataMutation = api.datasetEntries.export.useMutation();
const [exportData, exportInProgress] = useHandledAsyncCallback(async () => {
if (!dataset?.id || !selectedIds.size || !testingSplit) return;
const response = await exportDataMutation.mutateAsync({
datasetId: dataset.id,
datasetEntryIds: Array.from(selectedIds),
testingSplit,
removeDuplicates,
});
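// Decode the base64 export into a Blob, then download it through a temporary anchor element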
const dataUrl = `data:application/zip;base64,${response}`;
const blob = await fetch(dataUrl).then((res) => res.blob());
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `data.zip`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
disclosure.onClose();
clearSelectedIds();
}, [exportDataMutation, dataset, selectedIds, testingSplit, removeDuplicates]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={AiOutlineDownload} />
<Text>Export Logs</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
We'll export the <b>{selectedIds.size}</b> rows you have selected in the OpenAI
training format.
</Text>
<VStack alignItems="flex-start" spacing={4}>
<Flex
flexDir={{ base: "column", md: "row" }}
alignItems={{ base: "flex-start", md: "center" }}
>
<HStack w={48} alignItems="center" spacing={1}>
<Text fontWeight="bold">Testing Split:</Text>
<InfoCircle tooltipText="The percent of your logs that will be reserved for testing and saved in another file. Logs are split randomly." />
</HStack>
<HStack>
<NumberInput
defaultValue={10}
onChange={(_, num) => setTestingSplit(num)}
min={1}
max={100}
w={48}
>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
</HStack>
</Flex>
</VStack>
<VStack alignItems="flex-start" spacing={0}>
<Button
variant="unstyled"
color="blue.600"
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
>
<HStack>
<Text>Advanced Options</Text>
<Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
</HStack>
</Button>
<Collapse in={showAdvancedOptions} unmountOnExit={true}>
<VStack align="stretch" pt={4}>
<HStack>
<Checkbox
colorScheme="blue"
isChecked={removeDuplicates}
onChange={(e) => setRemoveDuplicates(e.target.checked)}
>
<Text>Remove duplicates</Text>
</Checkbox>
<InfoCircle tooltipText="To avoid overfitting and speed up training, automatically deduplicate logs with matching input and output." />
</HStack>
</VStack>
</Collapse>
</VStack>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button colorScheme="blue" onClick={exportData} isLoading={exportInProgress} minW={24}>
Download
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -1,21 +0,0 @@
import { RiFlaskLine } from "react-icons/ri";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
const ExperimentButton = () => {
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
return (
<ActionButton
onClick={() => {
console.log("experimenting with these ids", selectedIds);
}}
label="Experiment"
icon={RiFlaskLine}
isDisabled={selectedIds.size === 0}
requireBeta
/>
);
};
export default ExperimentButton;

View File

@@ -1,139 +0,0 @@
import { useState, useEffect } from "react";
import { VStack, HStack, Button, Text, Progress, IconButton, Portal } from "@chakra-ui/react";
import { BsX } from "react-icons/bs";
import { type RouterOutputs, api } from "~/utils/api";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { formatFileSize } from "~/utils/utils";
type FileUpload = RouterOutputs["datasets"]["listFileUploads"][0];
const FileUploadsCard = () => {
const dataset = useDataset();
const [fileUploadsRefetchInterval, setFileUploadsRefetchInterval] = useState<number>(500);
const fileUploads = api.datasets.listFileUploads.useQuery(
{ datasetId: dataset.data?.id as string },
{ enabled: !!dataset.data?.id, refetchInterval: fileUploadsRefetchInterval },
);
useEffect(() => {
if (fileUploads?.data?.some((fu) => fu.status !== "COMPLETE" && fu.status !== "ERROR")) {
setFileUploadsRefetchInterval(500);
} else {
setFileUploadsRefetchInterval(15000);
}
}, [fileUploads]);
const utils = api.useContext();
const hideFileUploadsMutation = api.datasets.hideFileUploads.useMutation();
const [hideAllFileUploads] = useHandledAsyncCallback(async () => {
if (!fileUploads.data?.length) return;
await hideFileUploadsMutation.mutateAsync({
fileUploadIds: fileUploads.data.map((upload) => upload.id),
});
await utils.datasets.listFileUploads.invalidate();
}, [hideFileUploadsMutation, fileUploads.data, utils]);
if (!fileUploads.data?.length) return null;
return (
<Portal>
<VStack
w={72}
borderRadius={8}
position="fixed"
bottom={8}
right={8}
overflow="hidden"
borderWidth={1}
boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1)"
minW={0}
bgColor="white"
>
<HStack p={4} w="full" bgColor="gray.200" justifyContent="space-between">
<Text fontWeight="bold">Uploads</Text>
<IconButton
aria-label="Close uploads"
as={BsX}
boxSize={6}
minW={0}
variant="ghost"
onClick={hideAllFileUploads}
cursor="pointer"
/>
</HStack>
{fileUploads?.data?.map((upload) => <FileUploadRow key={upload.id} fileUpload={upload} />)}
</VStack>
</Portal>
);
};
export default FileUploadsCard;
const FileUploadRow = ({ fileUpload }: { fileUpload: FileUpload }) => {
const { id, fileName, fileSize, progress, status, errorMessage } = fileUpload;
const utils = api.useContext();
const hideFileUploadsMutation = api.datasets.hideFileUploads.useMutation();
const [hideFileUpload, hidingInProgress] = useHandledAsyncCallback(async () => {
await hideFileUploadsMutation.mutateAsync({ fileUploadIds: [id] });
await utils.datasets.listFileUploads.invalidate();
}, [id, hideFileUploadsMutation, utils]);
useEffect(() => {
// Invalidate dataset entries list when upload is processed
if (status === "COMPLETE") void utils.datasetEntries.list.invalidate();
}, [status, utils]);
return (
<VStack w="full" alignItems="flex-start" p={4} borderBottomWidth={1}>
<HStack w="full" justifyContent="space-between" alignItems="flex-start">
<VStack alignItems="flex-start" spacing={0}>
<Text fontWeight="bold">{fileName}</Text>
<Text fontSize="xs">({formatFileSize(fileSize, 2)})</Text>
</VStack>
<Button
aria-label="Hide file upload"
minW={0}
variant="ghost"
isLoading={hidingInProgress}
onClick={hideFileUpload}
size="xs"
>
HIDE
</Button>
</HStack>
{errorMessage ? (
<Text alignSelf="center" pt={2}>
{errorMessage}
</Text>
) : (
<>
<Text alignSelf="center" fontSize="xs">
{getStatusText(status)}
</Text>
<Progress w="full" value={progress} borderRadius={2} />
</>
)}
</VStack>
);
};
const getStatusText = (status: FileUpload["status"]) => {
switch (status) {
case "PENDING":
return "Pending";
case "DOWNLOADING":
return "Loading Data";
case "PROCESSING":
return "Processing";
case "SAVING":
return "Saving";
case "COMPLETE":
return "Complete";
case "ERROR":
return "Error";
}
};
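FileUploadsCard formats sizes with formatFileSize from ~/utils/utils; that helper isn't part of this diff, so the following is only a plausible sketch of its shape (an assumption, not the repository's actual implementation):

// Hypothetical sketch; the real helper in ~/utils/utils may differ.
export const formatFileSize = (bytes: number, decimals = 1): string => {
  if (bytes === 0) return "0 B";
  const units = ["B", "KB", "MB", "GB", "TB"];
  const i = Math.floor(Math.log(bytes) / Math.log(1024));
  return `${(bytes / Math.pow(1024, i)).toFixed(decimals)} ${units[i]}`;
};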

View File

@@ -1,288 +0,0 @@
import { useState, useEffect, useRef, useCallback } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Box,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import pluralize from "pluralize";
import { AiOutlineCloudUpload, AiOutlineFile } from "react-icons/ai";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { validateTrainingRows, type TrainingRow, parseJSONL } from "./validateTrainingRows";
import { uploadDatasetEntryFile } from "~/utils/azure/website";
import { formatFileSize } from "~/utils/utils";
const UploadDataButton = () => {
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Upload Data"
icon={AiOutlineCloudUpload}
iconBoxSize={4}
requireBeta
/>
<UploadDataModal disclosure={disclosure} />
</>
);
};
export default UploadDataButton;
const UploadDataModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const [validationError, setValidationError] = useState<string | null>(null);
const [trainingRows, setTrainingRows] = useState<TrainingRow[] | null>(null);
const [file, setFile] = useState<File | null>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const handleFileDrop = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
const files = e.dataTransfer.files;
if (files.length > 0) {
processFile(files[0] as File);
}
};
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const files = e.target.files;
if (files && files.length > 0) {
processFile(files[0] as File);
}
};
const processFile = (file: File) => {
setFile(file);
// skip reading if file is larger than 10MB
if (file.size > 10000000) {
setTrainingRows(null);
return;
}
const reader = new FileReader();
reader.onload = (e: ProgressEvent<FileReader>) => {
const content = e.target?.result as string;
// Parse and validate the JSONL content before storing it in state
let parsedJSONL;
try {
parsedJSONL = parseJSONL(content) as TrainingRow[];
const validationError = validateTrainingRows(parsedJSONL);
if (validationError) {
setValidationError(validationError);
setTrainingRows(null);
return;
}
setTrainingRows(parsedJSONL);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
setValidationError("Unable to parse JSONL file: " + (e.message as string));
setTrainingRows(null);
return;
}
};
reader.readAsText(file);
};
const resetState = useCallback(() => {
setValidationError(null);
setTrainingRows(null);
setFile(null);
}, [setValidationError, setTrainingRows, setFile]);
useEffect(() => {
if (disclosure.isOpen) {
resetState();
}
}, [disclosure.isOpen, resetState]);
const triggerFileDownloadMutation = api.datasets.triggerFileDownload.useMutation();
const utils = api.useContext();
const [sendJSONL, sendingInProgress] = useHandledAsyncCallback(async () => {
if (!dataset || !file) return;
const blobName = await uploadDatasetEntryFile(file);
await triggerFileDownloadMutation.mutateAsync({
datasetId: dataset.id,
blobName,
fileName: file.name,
fileSize: file.size,
});
await utils.datasets.listFileUploads.invalidate();
disclosure.onClose();
}, [dataset, trainingRows, triggerFileDownloadMutation, file, utils]);
return (
<Modal
size={{ base: "xl", md: "2xl" }}
closeOnOverlayClick={false}
closeOnEsc={false}
{...disclosure}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Text>Upload Training Logs</Text>
</HStack>
</ModalHeader>
{!sendingInProgress && <ModalCloseButton />}
<ModalBody maxW="unset" p={8}>
<Box w="full" aspectRatio={1.5}>
{validationError && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<VStack w="full">
<Text fontSize={32} color="gray.500" fontWeight="bold">
Error
</Text>
<Text color="gray.500">{validationError}</Text>
</VStack>
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={resetState}
>
Try again
</Text>
</VStack>
)}
{!validationError && !file && (
<VStack
w="full"
h="full"
stroke="gray.300"
justifyContent="center"
borderRadius={8}
sx={{
"background-image": `url("data:image/svg+xml,%3csvg width='100%25' height='100%25' xmlns='http://www.w3.org/2000/svg'%3e%3crect x='2%25' y='2%25' width='96%25' height='96%25' fill='none' stroke='%23eee' stroke-width='4' stroke-dasharray='6%2c 14' stroke-dashoffset='0' stroke-linecap='square' rx='8' ry='8'/%3e%3c/svg%3e")`,
}}
onDragOver={(e) => e.preventDefault()}
onDrop={handleFileDrop}
>
<JsonFileIcon />
<Icon as={AiOutlineCloudUpload} boxSize={24} color="gray.300" />
<Text fontSize={32} color="gray.500" fontWeight="bold">
Drag & Drop
</Text>
<Text color="gray.500">
your .jsonl file here, or{" "}
<input
type="file"
ref={fileInputRef}
onChange={handleFileChange}
style={{ display: "none" }}
accept=".jsonl"
/>
<Text
as="span"
textDecor="underline"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={() => fileInputRef.current?.click()}
>
browse
</Text>
</Text>
</VStack>
)}
{!validationError && file && (
<VStack w="full" h="full" justifyContent="center" spacing={8}>
<JsonFileIcon />
<VStack w="full">
{trainingRows ? (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
Success
</Text>
<Text color="gray.500">
We'll upload <b>{trainingRows.length}</b>{" "}
{pluralize("row", trainingRows.length)} into <b>{dataset?.name}</b>.{" "}
</Text>
</>
) : (
<>
<Text fontSize={32} color="gray.500" fontWeight="bold">
{file.name}
</Text>
<Text color="gray.500">{formatFileSize(file.size)}</Text>
</>
)}
</VStack>
{!sendingInProgress && (
<Text
as="span"
textDecor="underline"
color="gray.500"
_hover={{ color: "orange.400" }}
cursor="pointer"
onClick={resetState}
>
Change file
</Text>
)}
</VStack>
)}
</Box>
</ModalBody>
<ModalFooter>
<HStack>
<Button
colorScheme="gray"
isDisabled={sendingInProgress}
onClick={disclosure.onClose}
minW={24}
>
Cancel
</Button>
<Button
colorScheme="orange"
onClick={sendJSONL}
isLoading={sendingInProgress}
minW={24}
isDisabled={!file || !!validationError}
>
Upload
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};
const JsonFileIcon = () => (
<Box position="relative" display="flex" alignItems="center" justifyContent="center">
<Icon as={AiOutlineFile} boxSize={24} color="gray.300" />
<Text position="absolute" color="orange.400" fontWeight="bold" fontSize={12} pt={4}>
JSONL
</Text>
</Box>
);

View File

@@ -1,71 +0,0 @@
import { type CreateChatCompletionRequestMessage } from "openai/resources/chat";
export type TrainingRow = {
input: CreateChatCompletionRequestMessage[];
output?: CreateChatCompletionRequestMessage;
};
export const parseJSONL = (jsonlString: string): unknown[] => {
const lines = jsonlString.trim().split("\n");
let lineNumber = 0;
const parsedLines = [];
try {
for (const line of lines) {
lineNumber++;
parsedLines.push(JSON.parse(line));
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
throw new Error(`Error parsing line ${lineNumber}: ${e.message as string}`);
}
return parsedLines;
};
export const validateTrainingRows = (rows: unknown): string | null => {
if (!Array.isArray(rows)) return "training data is not an array";
for (let i = 0; i < rows.length; i++) {
const row = rows[i] as TrainingRow;
let errorMessage: string | null = null;
try {
errorMessage = validateTrainingRow(row);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
errorMessage = error.message;
}
if (errorMessage) return `row ${i + 1}: ${errorMessage}`;
}
return null;
};
const validateTrainingRow = (row: TrainingRow): string | null => {
if (!row) return "empty row";
if (!row.input) return "missing input";
// Validate input
if (!Array.isArray(row.input)) return "input is not an array";
if ((row.input as unknown[]).some((x) => typeof x !== "object"))
return "input contains invalid item";
if (row.input.some((x) => !x)) return "input contains empty item";
if (row.input.some((x) => !x.content && !x.function_call))
return "input contains item with no content or function_call";
if (row.input.some((x) => x.function_call && !x.function_call.arguments))
return "input contains item with function_call but no arguments";
if (row.input.some((x) => x.function_call && !x.function_call.name))
return "input contains item with function_call but no name";
// Validate output
if (row.output) {
if (typeof row.output !== "object") return "output is not an object";
if (!row.output.content && !row.output.function_call)
return "output contains no content or function_call";
if (row.output.function_call && !row.output.function_call.arguments)
return "output contains function_call but no arguments";
if (row.output.function_call && !row.output.function_call.name)
return "output contains function_call but no name";
}
return null;
};
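For orientation, a minimal sketch of how these two helpers compose (the sample line and error handling are illustrative):

import { parseJSONL, validateTrainingRows, type TrainingRow } from "./validateTrainingRows";

const raw = '{"input":[{"role":"user","content":"Hi"}],"output":{"role":"assistant","content":"Hello!"}}';

const rows = parseJSONL(raw) as TrainingRow[]; // throws with a line number on malformed JSON
const error = validateTrainingRows(rows); // null means every row passed validation
if (error) throw new Error(error);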

View File

@@ -27,7 +27,7 @@ const DeleteExperimentDialog = ({
const mutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const [onDeleteConfirm, deletionInProgress] = useHandledAsyncCallback(async () => {
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experimentId) return;
await mutation.mutateAsync({ id: experimentId });
await utils.experiments.list.invalidate();
@@ -53,12 +53,7 @@ const DeleteExperimentDialog = ({
<Button ref={cancelRef} onClick={disclosure.onClose}>
Cancel
</Button>
<Button
colorScheme="red"
isLoading={deletionInProgress}
onClick={onDeleteConfirm}
ml={3}
>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>

View File

@@ -0,0 +1,57 @@
import {
Button,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
} from "@chakra-ui/react";
import { useRouter } from "next/router";
import { useRef } from "react";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
const experiment = useExperiment();
const deleteMutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await deleteMutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [deleteMutation, experiment.data?.id, router]);
return (
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
            If you delete this experiment, all the associated prompts and scenarios will be
            deleted as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};

View File

@@ -3,14 +3,17 @@ import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
import { useExperiment } from "~/utils/hooks";
import { BsGearFill } from "react-icons/bs";
import { TbGitFork } from "react-icons/tb";
import { useAppStore } from "~/state/store";
export const ExperimentHeaderButtons = ({ openDrawer }: { openDrawer: () => void }) => {
export const ExperimentHeaderButtons = () => {
const experiment = useExperiment();
const canModify = experiment.data?.access.canModify ?? false;
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
const openDrawer = useAppStore((s) => s.openDrawer);
if (experiment.isLoading) return null;
return (

View File

@@ -17,7 +17,7 @@ import { useRouter } from "next/router";
import { BsGearFill, BsGithub, BsPersonCircle } from "react-icons/bs";
import { IoStatsChartOutline } from "react-icons/io5";
import { RiHome3Line, RiFlaskLine } from "react-icons/ri";
import { AiOutlineThunderbolt, AiOutlineDatabase } from "react-icons/ai";
import { AiOutlineThunderbolt } from "react-icons/ai";
import { FaReadme } from "react-icons/fa";
import { signIn, useSession } from "next-auth/react";
@@ -78,7 +78,6 @@ const NavSidebar = () => {
<IconLink icon={RiHome3Line} label="Dashboard" href="/dashboard" />
<IconLink icon={IoStatsChartOutline} label="Request Logs" href="/request-logs" />
<IconLink icon={AiOutlineDatabase} label="Datasets" href="/datasets" beta />
<IconLink icon={AiOutlineThunderbolt} label="Fine Tunes" href="/fine-tunes" beta />
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
<VStack w="full" alignItems="flex-start" spacing={0} pt={8}>
@@ -117,8 +116,8 @@ const NavSidebar = () => {
</VStack>
<HStack
w="full"
px={{ base: 3, md: 4 }}
py={{ base: 0, md: 1 }}
px={{ base: 2, md: 4 }}
py={{ base: 1, md: 2 }}
as={ChakraLink}
justifyContent="start"
href="https://docs.openpipe.ai"
@@ -127,8 +126,8 @@ const NavSidebar = () => {
spacing={1}
>
<Icon as={FaReadme} boxSize={4} mr={2} />
<Text fontWeight="bold" fontSize="sm" display={{ base: "none", md: "flex" }}>
Open Documentation
<Text fontWeight="bold" fontSize="sm">
Read the Docs
</Text>
</HStack>
<Divider />

View File

@@ -3,18 +3,16 @@ import { useState } from "react";
import { Button, HStack, type ButtonProps, Icon, Text } from "@chakra-ui/react";
import { type IconType } from "react-icons";
import { useAppStore } from "~/state/store";
import { BetaModal } from "./BetaModal";
import { BetaModal } from "../BetaModal";
const ActionButton = ({
icon,
iconBoxSize = 3.5,
label,
requireBeta = false,
onClick,
...buttonProps
}: {
icon: IconType;
iconBoxSize?: number;
label: string;
requireBeta?: boolean;
onClick?: () => void;
@@ -41,9 +39,7 @@ const ActionButton = ({
{...buttonProps}
>
<HStack spacing={1}>
{icon && (
<Icon as={icon} boxSize={iconBoxSize} color={requireBeta ? "orange.400" : undefined} />
)}
{icon && <Icon as={icon} color={requireBeta ? "orange.400" : undefined} />}
<Text display={{ base: "none", md: "flex" }}>{label}</Text>
</HStack>
</Button>

View File

@@ -1,194 +0,0 @@
import { useState, useEffect, useMemo } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Flex,
Input,
useDisclosure,
type UseDisclosureReturn,
Checkbox,
} from "@chakra-ui/react";
import { FiPlusSquare } from "react-icons/fi";
import { useDatasets, useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import InputDropdown from "../InputDropdown";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import { useRouter } from "next/router";
const AddToDatasetButton = () => {
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
const disclosure = useDisclosure();
return (
<>
<ActionButton
onClick={disclosure.onOpen}
label="Add to Dataset"
icon={FiPlusSquare}
isDisabled={selectedLogIds.size === 0}
requireBeta
/>
<AddToDatasetModal disclosure={disclosure} />
</>
);
};
export default AddToDatasetButton;
const AddToDatasetModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds);
const router = useRouter();
const datasets = useDatasets().data;
const existingDatasetOptions = useMemo(
() =>
datasets?.length
? datasets.map((d) => ({ label: d.name, id: d.id }))
: [{ label: "", id: "" }],
[datasets],
);
const [selectedDatasetOption, setSelectedDatasetOption] = useState(existingDatasetOptions?.[0]);
const [newDatasetName, setNewDatasetName] = useState("");
const [createNewDataset, setCreateNewDataset] = useState(false);
useEffect(() => {
if (disclosure.isOpen) {
setSelectedDatasetOption(existingDatasetOptions?.[0]);
setCreateNewDataset(!existingDatasetOptions[0]?.id);
}
}, [disclosure.isOpen, existingDatasetOptions]);
const createDatasetEntriesMutation = api.datasetEntries.create.useMutation();
const [addToDataset, addingInProgress] = useHandledAsyncCallback(async () => {
if (
!selectedProjectId ||
!selectedLogIds.size ||
!(createNewDataset ? newDatasetName : selectedDatasetOption?.id)
)
return;
const datasetParams = createNewDataset
? { newDatasetParams: { projectId: selectedProjectId, name: newDatasetName } }
: { datasetId: selectedDatasetOption?.id };
const response = await createDatasetEntriesMutation.mutateAsync({
loggedCallIds: Array.from(selectedLogIds),
...datasetParams,
});
if (maybeReportError(response)) return;
const datasetId = response.payload;
await router.push({ pathname: "/datasets/[id]", query: { id: datasetId } });
disclosure.onClose();
clearSelectedLogIds();
}, [
selectedProjectId,
selectedLogIds,
createNewDataset,
selectedDatasetOption?.id,
newDatasetName,
router,
]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>
<HStack>
<Icon as={FiPlusSquare} />
<Text>Add to Dataset</Text>
</HStack>
</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
We'll add the <b>{selectedLogIds.size}</b> logs you have selected to the dataset you
choose.
</Text>
<VStack alignItems="flex-start" spacing={4}>
{existingDatasetOptions?.length && selectedDatasetOption && (
<Flex
flexDir={{ base: "column", md: "row" }}
alignItems={{ base: "flex-start", md: "center" }}
>
<Text fontWeight="bold" w={48}>
Dataset:
</Text>
<InputDropdown
options={existingDatasetOptions}
selectedOption={selectedDatasetOption}
getDisplayLabel={(option) => option.label}
onSelect={(option) => setSelectedDatasetOption(option)}
inputGroupProps={{ w: 48 }}
isDisabled={createNewDataset}
/>
<Checkbox
isChecked={createNewDataset}
onChange={(e) => setCreateNewDataset(e.target.checked)}
paddingLeft={4}
isDisabled={!existingDatasetOptions[0]?.id}
>
<Text>Create New Dataset</Text>
</Checkbox>
</Flex>
)}
{createNewDataset && (
<Flex
flexDir={{ base: "column", md: "row" }}
alignItems={{ base: "flex-start", md: "center" }}
>
<Text w={48} fontWeight="bold">
Dataset Name:
</Text>
<Input
w={48}
value={newDatasetName}
onChange={(e) => setNewDatasetName(e.target.value)}
/>
</Flex>
)}
</VStack>
</VStack>
</ModalBody>
<ModalFooter>
<HStack>
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
Cancel
</Button>
<Button
colorScheme="blue"
onClick={addToDataset}
isLoading={addingInProgress}
minW={24}
>
Add
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -17,7 +17,7 @@ import { useMemo } from "react";
import { useIsClientRehydrated, useTagNames } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
import { StaticColumnKeys } from "~/state/columnVisiblitySlice";
import ActionButton from "../ActionButton";
import ActionButton from "./ActionButton";
const ColumnVisiblityDropdown = () => {
const tagNames = useTagNames().data;

View File

@@ -28,7 +28,7 @@ import { BiExport } from "react-icons/bi";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "../ActionButton";
import ActionButton from "./ActionButton";
import InputDropdown from "../InputDropdown";
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
import InfoCircle from "../InfoCircle";
@@ -81,7 +81,7 @@ const ExportLogsModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) =>
return;
const response = await exportLogsMutation.mutateAsync({
projectId: selectedProjectId,
loggedCallIds: Array.from(selectedLogIds),
selectedLogIds: Array.from(selectedLogIds),
testingSplit,
selectedExportFormat,
removeDuplicates,

View File

@@ -20,18 +20,17 @@ import { AiTwotoneThunderbolt } from "react-icons/ai";
import humanId from "human-id";
import { useRouter } from "next/router";
import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import ActionButton from "../ActionButton";
import { useAppStore } from "~/state/store";
import ActionButton from "./ActionButton";
import InputDropdown from "../InputDropdown";
// import { FiChevronDown } from "react-icons/fi";
import { FiChevronDown } from "react-icons/fi";
const SUPPORTED_BASE_MODELS = ["llama2-7b", "llama2-13b", "llama2-70b", "gpt-3.5-turbo"];
const FineTuneButton = () => {
const datasetEntries = useDatasetEntries().data;
const numEntries = datasetEntries?.matchingEntryIds.length || 0;
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
const disclosure = useDisclosure();
@@ -41,7 +40,7 @@ const FineTuneButton = () => {
onClick={disclosure.onOpen}
label="Fine Tune"
icon={AiTwotoneThunderbolt}
isDisabled={numEntries === 0}
isDisabled={selectedLogIds.size === 0}
requireBeta
/>
<FineTuneModal disclosure={disclosure} />
@@ -52,8 +51,9 @@ const FineTuneButton = () => {
export default FineTuneButton;
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const dataset = useDataset().data;
const datasetEntries = useDatasetEntries().data;
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds);
const [selectedBaseModel, setSelectedBaseModel] = useState(SUPPORTED_BASE_MODELS[0]);
const [modelSlug, setModelSlug] = useState(humanId({ separator: "-", capitalize: false }));
@@ -71,17 +71,19 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
const createFineTuneMutation = api.fineTunes.create.useMutation();
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
if (!modelSlug || !selectedBaseModel || !dataset) return;
if (!selectedProjectId || !modelSlug || !selectedBaseModel || !selectedLogIds.size) return;
await createFineTuneMutation.mutateAsync({
projectId: selectedProjectId,
slug: modelSlug,
baseModel: selectedBaseModel,
datasetId: dataset.id,
selectedLogIds: Array.from(selectedLogIds),
});
await utils.fineTunes.list.invalidate();
await router.push({ pathname: "/fine-tunes" });
clearSelectedLogIds();
disclosure.onClose();
}, [createFineTuneMutation, modelSlug, selectedBaseModel]);
}, [createFineTuneMutation, selectedProjectId, selectedLogIds, modelSlug, selectedBaseModel]);
return (
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
@@ -97,8 +99,7 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
<ModalBody maxW="unset">
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
<Text>
We'll train on <b>{datasetEntries?.trainingCount}</b> and test on{" "}
<b>{datasetEntries?.testingCount}</b> entries in this dataset.
We'll train on the <b>{selectedLogIds.size}</b> logs you've selected.
</Text>
<VStack>
<HStack spacing={2} w="full">
@@ -131,12 +132,12 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
/>
</HStack>
</VStack>
{/* <Button variant="unstyled" color="blue.600">
<Button variant="unstyled" color="blue.600">
<HStack>
<Text>Advanced Options</Text>
<Icon as={FiChevronDown} />
</HStack>
</Button> */}
</Button>
</VStack>
</ModalBody>
<ModalFooter>

View File

@@ -9,14 +9,17 @@ import {
Collapse,
HStack,
VStack,
Button,
ButtonGroup,
Text,
Checkbox,
Link as ChakraLink,
} from "@chakra-ui/react";
import Link from "next/link";
import dayjs from "~/utils/dayjs";
import { type RouterOutputs } from "~/utils/api";
import { FormattedJson } from "../FormattedJson";
import { FormattedJson } from "./FormattedJson";
import { useAppStore } from "~/state/store";
import { useIsClientRehydrated, useLoggedCalls, useTagNames } from "~/utils/hooks";
import { useMemo } from "react";
@@ -173,16 +176,23 @@ export const TableRow = ({
<Tr>
<Td colSpan={visibleColumns.size + 1} w="full" p={0}>
<Collapse in={isExpanded} unmountOnExit={true}>
<HStack align="stretch" p={4}>
<VStack flex={1} align="stretch">
<Heading size="sm">Input</Heading>
<FormattedJson json={loggedCall.modelResponse?.reqPayload} />
</VStack>
<VStack flex={1} align="stretch">
<Heading size="sm">Output</Heading>
<FormattedJson json={loggedCall.modelResponse?.respPayload} />
</VStack>
</HStack>
<VStack p={4} align="stretch">
<HStack align="stretch">
<VStack flex={1} align="stretch">
<Heading size="sm">Input</Heading>
<FormattedJson json={loggedCall.modelResponse?.reqPayload} />
</VStack>
<VStack flex={1} align="stretch">
<Heading size="sm">Output</Heading>
<FormattedJson json={loggedCall.modelResponse?.respPayload} />
</VStack>
</HStack>
<ButtonGroup alignSelf="flex-end">
<Button as={Link} colorScheme="blue" href={{ pathname: "/experiments" }}>
Experiments
</Button>
</ButtonGroup>
</VStack>
</Collapse>
</Td>
</Tr>

View File

@@ -26,9 +26,6 @@ export const env = createEnv({
SMTP_PORT: z.string().default("placeholder"),
SMTP_LOGIN: z.string().default("placeholder"),
SMTP_PASSWORD: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_NAME: z.string().default("placeholder"),
AZURE_STORAGE_ACCOUNT_KEY: z.string().default("placeholder"),
AZURE_STORAGE_CONTAINER_NAME: z.string().default("placeholder"),
WORKER_CONCURRENCY: z
.string()
.default("10")
@@ -75,9 +72,6 @@ export const env = createEnv({
SMTP_PORT: process.env.SMTP_PORT,
SMTP_LOGIN: process.env.SMTP_LOGIN,
SMTP_PASSWORD: process.env.SMTP_PASSWORD,
AZURE_STORAGE_ACCOUNT_NAME: process.env.AZURE_STORAGE_ACCOUNT_NAME,
AZURE_STORAGE_ACCOUNT_KEY: process.env.AZURE_STORAGE_ACCOUNT_KEY,
AZURE_STORAGE_CONTAINER_NAME: process.env.AZURE_STORAGE_CONTAINER_NAME,
WORKER_CONCURRENCY: process.env.WORKER_CONCURRENCY,
WORKER_MAX_POOL_SIZE: process.env.WORKER_MAX_POOL_SIZE,
},

View File

@@ -14,7 +14,7 @@ export async function getCompletion(
let finalCompletion: ChatCompletion | null = null;
try {
if (onStream && !input.function_call) {
if (onStream) {
const resp = await openai.chat.completions.create(
{
...input,

View File

@@ -42,21 +42,24 @@ const modelProvider: OpenaiChatModelProvider = {
canStream: true,
getCompletion,
getUsage: (input, output) => {
if (output.choices.length === 0) return null;
const model = modelProvider.getModel(input);
if (!model) return null;
let inputTokens: number;
let outputTokens: number;
if (output?.usage) {
if (output.usage) {
inputTokens = output.usage.prompt_tokens;
outputTokens = output.usage.completion_tokens;
} else {
try {
inputTokens = countOpenAIChatTokens(model, input.messages);
outputTokens = output
? countOpenAIChatTokens(model, output.choices.map((c) => c.message).filter(truthyFilter))
: 0;
outputTokens = countOpenAIChatTokens(
model,
output.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
inputTokens = 0;
outputTokens = 0;
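The shape of getUsage after this change: trust provider-reported usage when present, otherwise estimate locally and degrade to zeros if the tokenizer fails. A rough standalone sketch of the same pattern (names follow the imports this file already uses; this is not the file's actual structure):

// Sketch only; assumes countOpenAIChatTokens and truthyFilter as imported above.
function resolveUsage(model: SupportedModel, input: CompletionCreateParams, output: ChatCompletion) {
  if (output.usage) {
    // Provider-reported counts are authoritative when present.
    return { inputTokens: output.usage.prompt_tokens, outputTokens: output.usage.completion_tokens };
  }
  try {
    return {
      inputTokens: countOpenAIChatTokens(model, input.messages),
      outputTokens: countOpenAIChatTokens(model, output.choices.map((c) => c.message).filter(truthyFilter)),
    };
  } catch {
    return { inputTokens: 0, outputTokens: 0 }; // tokenizer failure degrades to zeros
  }
}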

View File

@@ -8,9 +8,9 @@ const replicate = new Replicate({
});
const modelIds: Record<ReplicateLlama2Input["model"], string> = {
"7b-chat": "658b64a1e83d7caaba4ef10d5ee9c12c40770003f45852f05c2564962f921d8e",
"13b-chat": "7457c09004773f9f9710f7eb3b270287ffcebcfb23a13c8ec30cfb98f6bff9b2",
"70b-chat": "4dfd64cc207097970659087cf5670e3c1fbe02f83aa0f751e079cfba72ca790a",
"7b-chat": "d24902e3fa9b698cc208b5e63136c4e26e828659a9f09827ca6ec5bb83014381",
"13b-chat": "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3",
"70b-chat": "2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
};
export async function getCompletion(

View File

@@ -59,7 +59,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
) => Promise<CompletionResponse<OutputSchema>>;
getUsage: (
input: InputSchema,
output?: OutputSchema,
output: OutputSchema,
) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
// This is just a convenience for type inference, don't use it at runtime

View File

@@ -1,121 +0,0 @@
import {
Breadcrumb,
BreadcrumbItem,
Center,
Flex,
Icon,
Input,
VStack,
HStack,
useDisclosure,
} from "@chakra-ui/react";
import Link from "next/link";
import { useState, useEffect } from "react";
import { AiOutlineDatabase } from "react-icons/ai";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import DatasetConfigurationDrawer from "~/components/datasets/DatasetConfigurationDrawer/DatasetConfigurationDrawer";
import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons";
import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable/DatasetEntriesTable";
import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator";
import { useAppStore } from "~/state/store";
import FineTuneButton from "~/components/datasets/FineTuneButton";
// import ExperimentButton from "~/components/datasets/ExperimentButton";
import UploadDataButton from "~/components/datasets/UploadDataButton";
// import DownloadButton from "~/components/datasets/DownloadButton";
import DeleteButton from "~/components/datasets/DeleteButton";
import FileUploadsCard from "~/components/datasets/FileUploadsCard";
export default function Dataset() {
const utils = api.useContext();
const dataset = useDataset();
const drawerDisclosure = useDisclosure();
const [name, setName] = useState(dataset.data?.name || "");
useEffect(() => {
setName(dataset.data?.name || "");
}, [dataset.data?.name]);
useEffect(() => {
useAppStore.getState().sharedArgumentsEditor.loadMonaco().catch(console.error);
}, []);
const updateMutation = api.datasets.update.useMutation();
const [onSaveName] = useHandledAsyncCallback(async () => {
if (name && name !== dataset.data?.name && dataset.data?.id) {
await updateMutation.mutateAsync({
id: dataset.data.id,
name,
});
await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]);
}
}, [updateMutation, dataset.data?.id, dataset.data?.name, name]);
if (!dataset.isLoading && !dataset.data) {
return (
<AppShell title="Dataset not found">
<Center h="100%">
<div>Dataset not found 😕</div>
</Center>
</AppShell>
);
}
return (
<>
<AppShell title={dataset.data?.name}>
<VStack h="full" overflowY="scroll">
<PageHeaderContainer>
<Breadcrumb>
<BreadcrumbItem>
<ProjectBreadcrumbContents projectName={dataset.data?.project?.name} />
</BreadcrumbItem>
<BreadcrumbItem>
<Link href="/datasets">
<Flex alignItems="center" _hover={{ textDecoration: "underline" }}>
<Icon as={AiOutlineDatabase} boxSize={4} mr={2} /> Datasets
</Flex>
</Link>
</BreadcrumbItem>
<BreadcrumbItem isCurrentPage>
<Input
size="sm"
value={name}
onChange={(e) => setName(e.target.value)}
onBlur={onSaveName}
borderWidth={1}
borderColor="transparent"
fontSize={16}
px={0}
minW={{ base: 100, lg: 300 }}
flex={1}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
/>
</BreadcrumbItem>
</Breadcrumb>
<DatasetHeaderButtons openDrawer={drawerDisclosure.onOpen} />
</PageHeaderContainer>
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
<HStack w="full" justifyContent="flex-end">
<FineTuneButton />
<UploadDataButton />
{/* <ExperimentButton /> */}
{/* <DownloadButton /> */}
<DeleteButton />
</HStack>
<DatasetEntriesTable />
<DatasetEntryPaginator />
</VStack>
</VStack>
<FileUploadsCard />
</AppShell>
<DatasetConfigurationDrawer disclosure={drawerDisclosure} />
</>
);
}

View File

@@ -1,17 +0,0 @@
import { VStack, Text, Divider } from "@chakra-ui/react";
import AppShell from "~/components/nav/AppShell";
import DatasetsTable from "~/components/datasets/DatasetsTable";
export default function DatasetsPage() {
return (
<AppShell title="Datasets" requireAuth>
<VStack w="full" py={8} px={8} spacing={4} alignItems="flex-start">
<Text fontSize="2xl" fontWeight="bold">
Datasets
</Text>
<Divider />
<DatasetsTable />
</VStack>
</AppShell>
);
}

View File

@@ -8,25 +8,26 @@ import {
Input,
Text,
VStack,
useDisclosure,
} from "@chakra-ui/react";
import Link from "next/link";
import { useRouter } from "next/router";
import { useState, useEffect } from "react";
import { RiFlaskLine } from "react-icons/ri";
import OutputsTable from "~/components/OutputsTable";
import ExperimentSettingsDrawer from "~/components/experiments/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons";
import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
import { useSyncVariantEditor } from "~/state/sync";
import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons";
import Head from "next/head";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
export default function Experiment() {
const router = useRouter();
const utils = api.useContext();
useSyncVariantEditor();
@@ -43,7 +44,6 @@ export default function Experiment() {
useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error);
}, []);
const drawerDisclosure = useDisclosure();
const [label, setLabel] = useState(experiment.data?.label || "");
useEffect(() => {
setLabel(experiment.data?.label || "");
@@ -121,11 +121,11 @@ export default function Experiment() {
)}
</BreadcrumbItem>
</Breadcrumb>
<ExperimentHeaderButtons openDrawer={drawerDisclosure.onOpen} />
<ExperimentHeaderButtons />
</PageHeaderContainer>
<ExperimentSettingsDrawer disclosure={drawerDisclosure} />
<ExperimentSettingsDrawer />
<Box w="100%" overflowX="auto" flex={1} id="output-container">
<OutputsTable experimentId={experiment.data?.id} openDrawer={drawerDisclosure.onOpen} />
<OutputsTable experimentId={experiment.data?.id} />
</Box>
</VStack>
</AppShell>

View File

@@ -4,13 +4,14 @@ import { Text, VStack, Divider, HStack, Box } from "@chakra-ui/react";
import AppShell from "~/components/nav/AppShell";
import LoggedCallTable from "~/components/requestLogs/LoggedCallsTable";
import LoggedCallsPaginator from "~/components/requestLogs/LoggedCallsPaginator";
import ActionButton from "~/components/ActionButton";
import ActionButton from "~/components/requestLogs/ActionButton";
import { useAppStore } from "~/state/store";
import { RiFlaskLine } from "react-icons/ri";
import { FiFilter } from "react-icons/fi";
import LogFilters from "~/components/requestLogs/LogFilters/LogFilters";
import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown";
import FineTuneButton from "~/components/requestLogs/FineTuneButton";
import ExportButton from "~/components/requestLogs/ExportButton";
import AddToDatasetButton from "~/components/requestLogs/AddToDatasetButton";
export default function LoggedCalls() {
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
@@ -26,7 +27,16 @@ export default function LoggedCalls() {
</Text>
<Divider />
<HStack w="full" justifyContent="flex-end">
<AddToDatasetButton />
<FineTuneButton />
<ActionButton
onClick={() => {
console.log("experimenting with these ids", selectedLogIds);
}}
label="Experiment"
icon={RiFlaskLine}
isDisabled={selectedLogIds.size === 0}
requireBeta
/>
<ExportButton />
<ColumnVisiblityDropdown />
<ActionButton

View File

@@ -119,10 +119,10 @@ export const v1ApiRouter = createOpenApiRouter({
let usage;
let model;
if (reqPayload.success) {
if (reqPayload.success && respPayload.success) {
usage = modelProvider.getUsage(
input.reqPayload as CompletionCreateParams,
respPayload.success ? (input.respPayload as ChatCompletion) : undefined,
input.respPayload as ChatCompletion,
);
model = reqPayload.data.model;
}

View File

@@ -9,8 +9,6 @@ import { worldChampsRouter } from "./routers/worldChamps.router";
import { projectsRouter } from "./routers/projects.router";
import { dashboardRouter } from "./routers/dashboard.router";
import { loggedCallsRouter } from "./routers/loggedCalls.router";
import { datasetsRouter } from "./routers/datasets.router";
import { datasetEntriesRouter } from "./routers/datasetEntries.router";
import { fineTunesRouter } from "./routers/fineTunes.router";
import { usersRouter } from "./routers/users.router";
import { adminJobsRouter } from "./routers/adminJobs.router";
@@ -31,8 +29,6 @@ export const appRouter = createTRPCRouter({
projects: projectsRouter,
dashboard: dashboardRouter,
loggedCalls: loggedCallsRouter,
datasets: datasetsRouter,
datasetEntries: datasetEntriesRouter,
fineTunes: fineTunesRouter,
users: usersRouter,
adminJobs: adminJobsRouter,

View File

@@ -1,337 +0,0 @@
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import {
type ChatCompletion,
type CompletionCreateParams,
type CreateChatCompletionRequestMessage,
} from "openai/resources/chat";
import { TRPCError } from "@trpc/server";
import archiver from "archiver";
import { WritableStreamBuffer } from "stream-buffers";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import hashObject from "~/server/utils/hashObject";
import { type JsonValue } from "type-fest";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";
export const datasetEntriesRouter = createTRPCRouter({
list: protectedProcedure
.input(z.object({ datasetId: z.string(), page: z.number(), pageSize: z.number() }))
.query(async ({ input, ctx }) => {
const { datasetId, page, pageSize } = input;
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: datasetId },
});
await requireCanViewProject(projectId, ctx);
const [entries, matchingEntries, trainingCount, testingCount] = await prisma.$transaction([
prisma.datasetEntry.findMany({
where: {
datasetId: datasetId,
},
orderBy: [{ createdAt: "desc" }, { id: "desc" }],
skip: (page - 1) * pageSize,
take: pageSize,
}),
prisma.datasetEntry.findMany({
where: {
datasetId: datasetId,
},
select: {
id: true,
},
}),
prisma.datasetEntry.count({
where: {
datasetId: datasetId,
type: "TRAIN",
},
}),
prisma.datasetEntry.count({
where: {
datasetId: datasetId,
type: "TEST",
},
}),
]);
return {
entries,
matchingEntryIds: matchingEntries.map((entry) => entry.id),
trainingCount,
testingCount,
};
}),
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
const entry = await prisma.datasetEntry.findUniqueOrThrow({
where: { id: input.id },
include: {
dataset: true,
},
});
if (!entry.dataset) {
throw new TRPCError({ message: "Dataset not found for dataset entry", code: "NOT_FOUND" });
}
await requireCanViewProject(entry.dataset.projectId, ctx);
// findUniqueOrThrow already threw if the entry was missing, so no further null check is needed
return entry;
}),
create: protectedProcedure
.input(
z.object({
datasetId: z.string().optional(),
newDatasetParams: z
.object({
projectId: z.string(),
name: z.string(),
})
.optional(),
loggedCallIds: z.string().array().optional(),
}),
)
.mutation(async ({ input, ctx }) => {
let datasetId: string;
let trainingRatio = 0.8;
if (input.datasetId) {
datasetId = input.datasetId;
const { projectId, trainingRatio: datasetTrainingRatio } =
await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
trainingRatio = datasetTrainingRatio;
await requireCanModifyProject(projectId, ctx);
} else if (input.newDatasetParams) {
await requireCanModifyProject(input.newDatasetParams.projectId, ctx);
datasetId = uuidv4();
} else {
return error("No datasetId or newDatasetParams provided");
}
if (!input.loggedCallIds) {
return error("No loggedCallIds provided");
}
const loggedCalls = await prisma.loggedCall.findMany({
where: {
id: {
in: input.loggedCallIds,
},
modelResponse: {
isNot: null,
},
},
include: {
modelResponse: {
select: {
reqPayload: true,
respPayload: true,
inputTokens: true,
outputTokens: true,
},
},
},
orderBy: { createdAt: "desc" },
});
const trainingRows = loggedCalls.map((loggedCall) => {
const inputMessages = (
loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams
).messages;
let output: ChatCompletion.Choice.Message | undefined = undefined;
const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined;
if (resp && resp.choices?.[0]) {
output = resp.choices[0].message;
}
return {
input: inputMessages as unknown as CreateChatCompletionRequestMessage[],
output: output as unknown as CreateChatCompletionRequestMessage,
};
});
const datasetEntriesToCreate = await formatEntriesFromTrainingRows(datasetId, trainingRows);
// Ensure dataset and dataset entries are created atomically
await prisma.$transaction([
prisma.dataset.upsert({
where: { id: datasetId },
update: {},
create: {
id: datasetId,
projectId: input.newDatasetParams?.projectId ?? "",
name: input.newDatasetParams?.name ?? "",
trainingRatio,
},
}),
prisma.datasetEntry.createMany({
data: datasetEntriesToCreate,
}),
]);
return success(datasetId);
}),
update: protectedProcedure
.input(
z.object({
id: z.string(),
updates: z.object({
type: z.enum(["TRAIN", "TEST"]).optional(),
input: z.string().optional(),
output: z.string().optional(),
}),
}),
)
.mutation(async ({ input, ctx }) => {
const { dataset } = await prisma.datasetEntry.findUniqueOrThrow({
where: { id: input.id },
include: {
dataset: true,
},
});
if (!dataset) {
return error("Dataset not found for dataset entry");
}
await requireCanModifyProject(dataset.projectId, ctx);
let parsedInput = undefined;
let inputTokens = undefined;
if (input.updates.input) {
parsedInput = JSON.parse(input.updates.input);
inputTokens = countOpenAIChatTokens(
"gpt-4-0613",
parsedInput as unknown as CreateChatCompletionRequestMessage[],
);
}
let parsedOutput = undefined;
let outputTokens = undefined;
// The client might send "null" as a string, so we need to check for that
if (input.updates.output && input.updates.output !== "null") {
parsedOutput = JSON.parse(input.updates.output);
outputTokens = countOpenAIChatTokens("gpt-4-0613", [
parsedOutput as unknown as ChatCompletion.Choice.Message,
]);
}
await prisma.datasetEntry.update({
where: { id: input.id },
data: {
type: input.updates.type,
input: parsedInput,
output: parsedOutput,
inputTokens,
outputTokens,
},
});
return success("Dataset entry updated");
}),
delete: protectedProcedure
.input(z.object({ ids: z.string().array() }))
.mutation(async ({ input, ctx }) => {
if (input.ids.length === 0) {
return error("No ids provided");
}
const { dataset } = await prisma.datasetEntry.findUniqueOrThrow({
where: { id: input.ids[0] },
include: {
dataset: true,
},
});
if (!dataset) {
return error("Dataset not found for dataset entry");
}
await requireCanModifyProject(dataset.projectId, ctx);
await prisma.datasetEntry.deleteMany({
where: {
id: {
in: input.ids,
},
datasetId: dataset?.id,
},
});
return success("Dataset entries deleted");
}),
export: protectedProcedure
.input(
z.object({
datasetId: z.string(),
datasetEntryIds: z.string().array(),
testingSplit: z.number(),
removeDuplicates: z.boolean(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
const datasetEntries = await ctx.prisma.datasetEntry.findMany({
where: {
id: {
in: input.datasetEntryIds,
},
},
});
let rows: TrainingRow[] = datasetEntries.map((entry) => ({
input: entry.input as unknown as CreateChatCompletionRequestMessage[],
output: entry.output as unknown as CreateChatCompletionRequestMessage,
}));
if (input.removeDuplicates) {
const deduplicatedRows = [];
const rowHashSet = new Set<string>();
for (const row of rows) {
const rowHash = hashObject(row as unknown as JsonValue);
if (!rowHashSet.has(rowHash)) {
rowHashSet.add(rowHash);
deduplicatedRows.push(row);
}
}
rows = deduplicatedRows;
}
const splitIndex = Math.floor((rows.length * input.testingSplit) / 100);
const testingData = rows.slice(0, splitIndex);
const trainingData = rows.slice(splitIndex);
// Convert arrays to JSONL format
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
const output = new WritableStreamBuffer();
const archive = archiver("zip");
archive.pipe(output);
archive.append(trainingDataJSONL, { name: "train.jsonl" });
archive.append(testingDataJSONL, { name: "test.jsonl" });
await archive.finalize();
// Convert buffer to base64
const base64 = output.getContents().toString("base64");
return base64;
}),
});
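Two details of the export procedure are easy to miss: duplicates are dropped by hashing whole rows, and testingSplit is a percentage taken off the front of the row list. A worked example of the split math, with made-up numbers:

// Illustrative only: 10 rows with testingSplit = 20
const rows = Array.from({ length: 10 }, (_, i) => ({ id: i }));
const testingSplit = 20;
const splitIndex = Math.floor((rows.length * testingSplit) / 100); // 2
const testingData = rows.slice(0, splitIndex); // first 2 rows -> test.jsonl
const trainingData = rows.slice(splitIndex); // remaining 8 rows -> train.jsonl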

View File

@@ -1,183 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { generateServiceClientUrl } from "~/utils/azure/server";
import { queueImportDatasetEntries } from "~/server/tasks/importDatasetEntries.task";
export const datasetsRouter = createTRPCRouter({
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
const dataset = await prisma.dataset.findUniqueOrThrow({
where: { id: input.id },
include: {
project: true,
},
});
await requireCanViewProject(dataset.projectId, ctx);
return dataset;
}),
list: protectedProcedure
.input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewProject(input.projectId, ctx);
return await prisma.dataset.findMany({
where: {
projectId: input.projectId,
},
include: {
_count: {
select: {
datasetEntries: true,
},
},
},
orderBy: { createdAt: "desc" },
});
}),
create: protectedProcedure
.input(
z.object({
projectId: z.string(),
name: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
await requireCanModifyProject(input.projectId, ctx);
const dataset = await prisma.dataset.create({
data: {
projectId: input.projectId,
name: input.name,
},
});
return success(dataset);
}),
update: protectedProcedure
.input(
z.object({
id: z.string(),
name: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyProject(projectId, ctx);
await prisma.dataset.update({
where: { id: input.id },
data: {
name: input.name,
},
});
return success("Dataset updated");
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyProject(projectId, ctx);
await prisma.dataset.delete({
where: { id: input.id },
});
return success("Dataset deleted");
}),
getServiceClientUrl: protectedProcedure
.input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => {
// Only users who can modify the project may request a SAS token
await requireCanModifyProject(input.projectId, ctx);
return generateServiceClientUrl();
}),
triggerFileDownload: protectedProcedure
.input(
z.object({
datasetId: z.string(),
blobName: z.string(),
fileName: z.string(),
fileSize: z.number(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
const { id } = await prisma.datasetFileUpload.create({
data: {
datasetId: input.datasetId,
blobName: input.blobName,
status: "PENDING",
fileName: input.fileName,
fileSize: input.fileSize,
uploadedAt: new Date(),
},
});
await queueImportDatasetEntries(id);
}),
listFileUploads: protectedProcedure
.input(z.object({ datasetId: z.string() }))
.query(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: { id: input.datasetId },
});
await requireCanViewProject(projectId, ctx);
return await prisma.datasetFileUpload.findMany({
where: {
datasetId: input.datasetId,
visible: true,
},
orderBy: { createdAt: "desc" },
});
}),
hideFileUploads: protectedProcedure
.input(z.object({ fileUploadIds: z.string().array() }))
.mutation(async ({ input, ctx }) => {
if (!input.fileUploadIds.length) return error("No file upload ids provided");
const {
dataset: { projectId, id: datasetId },
} = await prisma.datasetFileUpload.findUniqueOrThrow({
where: { id: input.fileUploadIds[0] },
select: {
dataset: {
select: {
id: true,
projectId: true,
},
},
},
});
await requireCanModifyProject(projectId, ctx);
await prisma.datasetFileUpload.updateMany({
where: {
id: {
in: input.fileUploadIds,
},
datasetId,
},
data: {
visible: false,
},
});
}),
});

View File

@@ -1,4 +1,6 @@
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { type Prisma } from "@prisma/client";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
@@ -53,18 +55,14 @@ export const fineTunesRouter = createTRPCRouter({
create: protectedProcedure
.input(
z.object({
datasetId: z.string(),
projectId: z.string(),
selectedLogIds: z.array(z.string()),
slug: z.string(),
baseModel: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
const { projectId } = await prisma.dataset.findUniqueOrThrow({
where: {
id: input.datasetId,
},
});
await requireCanModifyProject(projectId, ctx);
await requireCanModifyProject(input.projectId, ctx);
const existingFineTune = await prisma.fineTune.findFirst({
where: {
@@ -76,14 +74,39 @@ export const fineTunesRouter = createTRPCRouter({
return error("A fine tune with that slug already exists");
}
await prisma.fineTune.create({
data: {
projectId,
slug: input.slug,
baseModel: input.baseModel,
datasetId: input.datasetId,
},
});
const newDatasetId = uuidv4();
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyDatasetInput[] =
input.selectedLogIds.map((loggedCallId) => ({
loggedCallId,
}));
await prisma.$transaction([
prisma.dataset.create({
data: {
id: newDatasetId,
name: input.slug,
project: {
connect: {
id: input.projectId,
},
},
datasetEntries: {
createMany: {
data: datasetEntriesToCreate,
},
},
},
}),
prisma.fineTune.create({
data: {
projectId: input.projectId,
slug: input.slug,
baseModel: input.baseModel,
datasetId: newDatasetId,
},
}),
]);
return success();
}),
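One detail worth noting in the new create flow: prisma.$transaction runs the two creates in order and atomically, so the dataset (with its nested entries) exists before the fineTune row references newDatasetId, and a failure in either leaves no orphan behind. A stripped-down sketch of the pattern (values are placeholders):

// Placeholder values; mirrors the ordered, atomic create pattern above.
await prisma.$transaction([
  prisma.dataset.create({
    data: { id: newDatasetId, name: "my-slug", project: { connect: { id: "project-id" } } },
  }),
  prisma.fineTune.create({
    data: { projectId: "project-id", slug: "my-slug", baseModel: "llama2-7b", datasetId: newDatasetId },
  }),
]);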

View File

@@ -189,7 +189,7 @@ export const loggedCallsRouter = createTRPCRouter({
.input(
z.object({
projectId: z.string(),
loggedCallIds: z.string().array(),
selectedLogIds: z.string().array(),
testingSplit: z.number(),
selectedExportFormat: z.string(),
removeDuplicates: z.boolean(),
@@ -203,7 +203,7 @@ export const loggedCallsRouter = createTRPCRouter({
where: {
originalLoggedCall: {
projectId: input.projectId,
id: { in: input.loggedCallIds },
id: { in: input.selectedLogIds },
},
statusCode: 200,
},

View File

@@ -93,12 +93,17 @@ export const promptVariantsRouter = createTRPCRouter({
visible: true,
},
});
const finishedCount = await prisma.scenarioVariantCell.count({
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
retrievalStatus: {
in: ["COMPLETE", "ERROR"],
modelResponses: {
some: {
outdated: false,
respPayload: {
not: Prisma.AnyNull,
},
},
},
},
});
@@ -126,7 +131,7 @@ export const promptVariantsRouter = createTRPCRouter({
const inputTokens = overallTokens._sum?.inputTokens ?? 0;
const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingCompletions = finishedCount < scenarioCount;
const awaitingCompletions = outputCount < scenarioCount;
const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length,
@@ -138,7 +143,7 @@ export const promptVariantsRouter = createTRPCRouter({
outputTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
finishedCount,
outputCount,
awaitingCompletions,
awaitingEvals,
};

View File

@@ -1,152 +0,0 @@
import { type DatasetFileUpload } from "@prisma/client";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { downloadBlobToString } from "~/utils/azure/server";
import {
type TrainingRow,
validateTrainingRows,
parseJSONL,
} from "~/components/datasets/validateTrainingRows";
import { formatEntriesFromTrainingRows } from "~/server/utils/createEntriesFromTrainingRows";
export type ImportDatasetEntriesJob = {
datasetFileUploadId: string;
};
export const importDatasetEntries = defineTask<ImportDatasetEntriesJob>(
"importDatasetEntries",
async (task) => {
const { datasetFileUploadId } = task;
const datasetFileUpload = await prisma.datasetFileUpload.findUnique({
where: { id: datasetFileUploadId },
});
if (!datasetFileUpload) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: "Dataset File Upload not found",
status: "ERROR",
},
});
return;
}
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "DOWNLOADING",
progress: 5,
},
});
const onBlobDownloadProgress = async (progress: number) => {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
progress: 5 + Math.floor((progress / datasetFileUpload.fileSize) * 25),
},
});
};
const jsonlStr = await downloadBlobToString(datasetFileUpload.blobName, onBlobDownloadProgress);
let trainingRows: TrainingRow[] = [];
let validationError: string | null = null;
try {
trainingRows = parseJSONL(jsonlStr) as TrainingRow[];
validationError = validateTrainingRows(trainingRows);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
validationError = e.message;
}
if (validationError) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: `Invalid JSONL: ${validationError}`,
status: "ERROR",
},
});
return;
}
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "PROCESSING",
progress: 30,
},
});
const updatePromises: Promise<DatasetFileUpload>[] = [];
const updateCallback = async (progress: number) => {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
progress: 30 + Math.floor((progress / trainingRows.length) * 69),
},
});
};
let datasetEntriesToCreate;
try {
datasetEntriesToCreate = await formatEntriesFromTrainingRows(
datasetFileUpload.datasetId,
trainingRows,
updateCallback,
500,
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
errorMessage: `Error formatting rows: ${e.message as string}`,
status: "ERROR",
visible: true,
},
});
return;
}
await Promise.all(updatePromises);
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "SAVING",
progress: 99,
},
});
await prisma.datasetEntry.createMany({
data: datasetEntriesToCreate,
});
await prisma.datasetFileUpload.update({
where: { id: datasetFileUploadId },
data: {
status: "COMPLETE",
progress: 100,
visible: true,
},
});
},
);
export const queueImportDatasetEntries = async (datasetFileUploadId: string) => {
await Promise.all([
prisma.datasetFileUpload.update({
where: {
id: datasetFileUploadId,
},
data: {
errorMessage: null,
status: "PENDING",
},
}),
importDatasetEntries.enqueue({ datasetFileUploadId }),
]);
};
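The progress values in this task trace one 0-100 bar across phases: 5 at download start, 5-30 while streaming the blob, 30-99 while formatting rows, and 100 on completion. The two mappings, pulled out for clarity (the numbers are illustrative):

const downloadProgress = (bytes: number, fileSize: number) =>
  5 + Math.floor((bytes / fileSize) * 25); // 5-30 while downloading
const formatProgress = (row: number, totalRows: number) =>
  30 + Math.floor((row / totalRows) * 69); // 30-99 while formatting
console.log(downloadProgress(5_000_000, 10_000_000)); // 17
console.log(formatProgress(500, 1000)); // 64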

View File

@@ -5,11 +5,10 @@ import "../../../sentry.server.config";
import { env } from "~/env.mjs";
import { queryModel } from "./queryModel.task";
import { runNewEval } from "./runNewEval.task";
import { importDatasetEntries } from "./importDatasetEntries.task";
console.log("Starting worker");
const registeredTasks = [queryModel, runNewEval, importDatasetEntries];
const registeredTasks = [queryModel, runNewEval];
const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;

View File

@@ -1,70 +0,0 @@
import { type Prisma } from "@prisma/client";
import { shuffle } from "lodash-es";
import {
type CreateChatCompletionRequestMessage,
type ChatCompletion,
} from "openai/resources/chat";
import { prisma } from "~/server/db";
import { type TrainingRow } from "~/components/datasets/validateTrainingRows";
import { countLlamaChatTokens } from "~/utils/countTokens";
export const formatEntriesFromTrainingRows = async (
datasetId: string,
trainingRows: TrainingRow[],
updateCallback?: (progress: number) => Promise<void>,
updateFrequency = 1000,
) => {
const [dataset, existingTrainingCount, existingTestingCount] = await prisma.$transaction([
prisma.dataset.findUnique({ where: { id: datasetId } }),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TRAIN",
},
}),
prisma.datasetEntry.count({
where: {
datasetId,
type: "TEST",
},
}),
]);
const trainingRatio = dataset?.trainingRatio ?? 0.8;
const newTotalEntries = existingTrainingCount + existingTestingCount + trainingRows.length;
const numTrainingToAdd = Math.floor(trainingRatio * newTotalEntries) - existingTrainingCount;
const numTestingToAdd = trainingRows.length - numTrainingToAdd;
const typesToAssign = shuffle([
...Array(numTrainingToAdd).fill("TRAIN"),
...Array(numTestingToAdd).fill("TEST"),
]);
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = [];
let i = 0;
for (const row of trainingRows) {
if (updateCallback && i % updateFrequency === 0) await updateCallback(i);
let outputTokens = 0;
if (row.output) {
outputTokens = countLlamaChatTokens([row.output as unknown as ChatCompletion.Choice.Message]);
}
// console.log("outputTokens", outputTokens);
datasetEntriesToCreate.push({
datasetId: datasetId,
input: row.input as unknown as Prisma.InputJsonValue,
output: (row.output as unknown as Prisma.InputJsonValue) ?? {
role: "assistant",
content: "",
},
inputTokens: countLlamaChatTokens(
row.input as unknown as CreateChatCompletionRequestMessage[],
),
outputTokens,
type: typesToAssign.pop() as "TRAIN" | "TEST",
});
i++;
}
return datasetEntriesToCreate;
};
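The TRAIN/TEST assignment above rebalances against whatever the dataset already holds rather than splitting the new rows in isolation. A worked example with made-up counts:

// Illustrative: 70 TRAIN / 30 TEST already exist, trainingRatio = 0.8, 100 new rows arrive.
const existingTrain = 70, existingTest = 30, incoming = 100, ratio = 0.8;
const newTotal = existingTrain + existingTest + incoming; // 200
const numTrainingToAdd = Math.floor(ratio * newTotal) - existingTrain; // 160 - 70 = 90
const numTestingToAdd = incoming - numTrainingToAdd; // 10
// The combined dataset lands at 160 TRAIN / 40 TEST, i.e. the 0.8 ratio overall.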

View File

@@ -1,33 +0,0 @@
import { type SliceCreator } from "./store";
export type SelectedDatasetEntriesSlice = {
selectedIds: Set<string>;
toggleSelectedId: (id: string) => void;
addSelectedIds: (ids: string[]) => void;
clearSelectedIds: () => void;
};
export const createSelectedDatasetEntriesSlice: SliceCreator<SelectedDatasetEntriesSlice> = (
set,
) => ({
selectedIds: new Set(),
toggleSelectedId: (id: string) =>
set((state) => {
if (state.selectedDatasetEntries.selectedIds.has(id)) {
state.selectedDatasetEntries.selectedIds.delete(id);
} else {
state.selectedDatasetEntries.selectedIds.add(id);
}
}),
addSelectedIds: (ids: string[]) =>
set((state) => {
state.selectedDatasetEntries.selectedIds = new Set([
...state.selectedDatasetEntries.selectedIds,
...ids,
]);
}),
clearSelectedIds: () =>
set((state) => {
state.selectedDatasetEntries.selectedIds = new Set();
}),
});
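Like the other slices, this one is consumed through useAppStore selectors; a hypothetical component-side sketch:

// Hypothetical usage of the slice above inside a React component.
const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds);
const toggleSelectedId = useAppStore((s) => s.selectedDatasetEntries.toggleSelectedId);
toggleSelectedId("entry-1"); // select
toggleSelectedId("entry-1"); // toggle back off
// Re-read selectedIds via the selector after each update; the store replaces the Set
// immutably (note the enableMapSet() immer setup in store.ts below).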

View File

@@ -7,7 +7,7 @@ export type SelectedLogsSlice = {
clearSelectedLogIds: () => void;
};
export const createSelectedLogsSlice: SliceCreator<SelectedLogsSlice> = (set) => ({
export const createSelectedLogsSlice: SliceCreator<SelectedLogsSlice> = (set, get) => ({
selectedLogIds: new Set(),
toggleSelectedLogId: (id: string) =>
set((state) => {

View File

@@ -1,33 +0,0 @@
import loader, { type Monaco } from "@monaco-editor/loader";
import { type SliceCreator } from "./store";
export const editorBackground = "#fafafa";
export type SharedArgumentsEditorSlice = {
monaco: null | Monaco;
loadMonaco: () => Promise<void>;
};
export const createArgumentsEditorSlice: SliceCreator<SharedArgumentsEditorSlice> = (set, get) => ({
monaco: loader.__getMonacoInstance(),
loadMonaco: async () => {
// We only want to run this client-side
if (typeof window === "undefined") return;
const monaco = await loader.init();
monaco.editor.defineTheme("customTheme", {
base: "vs",
inherit: true,
rules: [],
colors: {
"editor.background": "#ffffff",
},
});
set((state) => {
state.sharedArgumentsEditor.monaco = monaco;
});
},
});

View File

@@ -7,17 +7,9 @@ import {
type SharedVariantEditorSlice,
createVariantEditorSlice,
} from "./sharedVariantEditor.slice";
import {
type SharedArgumentsEditorSlice,
createArgumentsEditorSlice,
} from "./sharedArgumentsEditor.slice";
import { type APIClient } from "~/utils/api";
import { type PersistedState, persistOptions } from "./persist";
import { type SelectedLogsSlice, createSelectedLogsSlice } from "./selectedLogsSlice";
import {
type SelectedDatasetEntriesSlice,
createSelectedDatasetEntriesSlice,
} from "./selectedDatasetEntriesSlice";
import { type LogFiltersSlice, createLogFiltersSlice } from "./logFiltersSlice";
import { type ColumnVisibilitySlice, createColumnVisibilitySlice } from "./columnVisiblitySlice";
import { type FeatureFlagsSlice, createFeatureFlagsSlice } from "./featureFlags";
@@ -26,14 +18,15 @@ enableMapSet();
export type State = {
isRehydrated: boolean;
drawerOpen: boolean;
openDrawer: () => void;
closeDrawer: () => void;
api: APIClient | null;
setApi: (api: APIClient) => void;
sharedVariantEditor: SharedVariantEditorSlice;
sharedArgumentsEditor: SharedArgumentsEditorSlice;
selectedProjectId: string | null;
setSelectedProjectId: (id: string) => void;
selectedLogs: SelectedLogsSlice;
selectedDatasetEntries: SelectedDatasetEntriesSlice;
logFilters: LogFiltersSlice;
columnVisibility: ColumnVisibilitySlice;
featureFlags: FeatureFlagsSlice;
@@ -53,15 +46,22 @@ const useBaseStore = create<State, [["zustand/persist", PersistedState], ["zusta
set((state) => {
state.api = api;
}),
drawerOpen: false,
openDrawer: () =>
set((state) => {
state.drawerOpen = true;
}),
closeDrawer: () =>
set((state) => {
state.drawerOpen = false;
}),
sharedVariantEditor: createVariantEditorSlice(set, get, ...rest),
sharedArgumentsEditor: createArgumentsEditorSlice(set, get, ...rest),
selectedProjectId: null,
setSelectedProjectId: (id: string) =>
set((state) => {
state.selectedProjectId = id;
}),
selectedLogs: createSelectedLogsSlice(set, get, ...rest),
selectedDatasetEntries: createSelectedDatasetEntriesSlice(set, get, ...rest),
logFilters: createLogFiltersSlice(set, get, ...rest),
columnVisibility: createColumnVisibilitySlice(set, get, ...rest),
featureFlags: createFeatureFlagsSlice(set, get, ...rest),

View File

@@ -1,93 +0,0 @@
import {
BlobServiceClient,
generateAccountSASQueryParameters,
AccountSASPermissions,
AccountSASServices,
AccountSASResourceTypes,
StorageSharedKeyCredential,
SASProtocol,
} from "@azure/storage-blob";
const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
if (!accountName) throw Error("Azure Storage accountName not found");
const accountKey = process.env.AZURE_STORAGE_ACCOUNT_KEY;
if (!accountKey) throw Error("Azure Storage accountKey not found");
const containerName = process.env.AZURE_STORAGE_CONTAINER_NAME;
if (!containerName) throw Error("Azure Storage containerName not found");
const sharedKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);
const blobServiceClient = new BlobServiceClient(
`https://${accountName}.blob.core.windows.net`,
sharedKeyCredential,
);
const containerClient = blobServiceClient.getContainerClient(containerName);
export const generateServiceClientUrl = () => {
const sasOptions = {
services: AccountSASServices.parse("b").toString(), // blobs
resourceTypes: AccountSASResourceTypes.parse("sco").toString(), // service, container, object
permissions: AccountSASPermissions.parse("w"), // write permissions
protocol: SASProtocol.Https,
startsOn: new Date(),
expiresOn: new Date(new Date().valueOf() + 10 * 60 * 1000), // 10 minutes
};
let sasToken = generateAccountSASQueryParameters(sasOptions, sharedKeyCredential).toString();
// remove leading "?"
sasToken = sasToken[0] === "?" ? sasToken.substring(1) : sasToken;
return {
serviceClientUrl: `https://${accountName}.blob.core.windows.net?${sasToken}`,
containerName,
};
};
export async function downloadBlobToString(
blobName: string,
onProgress?: (progress: number) => Promise<void>,
chunkInterval?: number,
) {
const blobClient = containerClient.getBlobClient(blobName);
const downloadResponse = await blobClient.download();
if (!downloadResponse) throw Error("error downloading blob");
if (!downloadResponse.readableStreamBody)
throw Error("downloadResponse.readableStreamBody not found");
const downloaded = await streamToBuffer(
downloadResponse.readableStreamBody,
onProgress,
chunkInterval,
);
return downloaded.toString();
}
async function streamToBuffer(
readableStream: NodeJS.ReadableStream,
onProgress?: (progress: number) => Promise<void>,
chunkInterval = 1048576, // send progress every 1MB
): Promise<Buffer> {
return new Promise((resolve, reject) => {
const chunks: Uint8Array[] = [];
let bytesDownloaded = 0;
let lastReportedByteCount = 0;
readableStream.on("data", (data: ArrayBuffer) => {
chunks.push(data instanceof Buffer ? data : Buffer.from(data));
bytesDownloaded += data.byteLength;
if (onProgress && bytesDownloaded - lastReportedByteCount >= chunkInterval) {
void onProgress(bytesDownloaded); // progress in Bytes
lastReportedByteCount = bytesDownloaded;
}
});
readableStream.on("end", () => {
resolve(Buffer.concat(chunks));
});
readableStream.on("error", reject);
});
}
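A sketch of how the import task above consumes these helpers (the blob name is a placeholder):

// Hypothetical caller: stream a blob to a string, reporting roughly every 1 MB by default.
const jsonlStr = await downloadBlobToString("my-upload-abc123.jsonl", async (bytesDownloaded) => {
  console.log(`downloaded ${bytesDownloaded} bytes so far`);
});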

View File

@@ -1,30 +0,0 @@
import { BlobServiceClient } from "@azure/storage-blob";
import { v4 as uuidv4 } from "uuid";
import { useAppStore } from "~/state/store";
export const uploadDatasetEntryFile = async (file: File) => {
const { selectedProjectId: projectId, api } = useAppStore.getState();
if (!projectId) throw Error("projectId not found");
if (!api) throw Error("api not initialized");
const { serviceClientUrl, containerName } = await api.client.datasets.getServiceClientUrl.query({
projectId,
});
const blobServiceClient = new BlobServiceClient(serviceClientUrl);
// create container client
const containerClient = blobServiceClient.getContainerClient(containerName);
// base name without extension
const basename = file.name.split("/").pop()?.split(".").shift();
if (!basename) throw Error("basename not found");
const blobName = `${basename}-${uuidv4()}.jsonl`;
// create blob client
const blobClient = containerClient.getBlockBlobClient(blobName);
// upload file
await blobClient.uploadData(file);
return blobName;
};
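Client-side, the uploaded blob is then registered with the server so the import task can pick it up; a sketch assuming the same api.client wiring used above (the handler and mutate call are hypothetical):

// Hypothetical change handler for an <input type="file"> element.
const onFileChosen = async (datasetId: string, file: File) => {
  const blobName = await uploadDatasetEntryFile(file);
  const { api } = useAppStore.getState();
  await api?.client.datasets.triggerFileDownload.mutate({
    datasetId,
    blobName,
    fileName: file.name,
    fileSize: file.size,
  });
};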

View File

@@ -1,7 +1,5 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
import llamaTokenizer from "llama-tokenizer-js";
import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion";
interface GPTTokensMessageItem {
@@ -14,21 +12,6 @@ export const countOpenAIChatTokens = (
model: SupportedModel,
messages: ChatCompletion.Choice.Message[],
) => {
const reformattedMessages = messages.map((message) => ({
role: message.role,
// Not completely accurate, but gives a rough idea of the token count
content: message.content ?? JSON.stringify(message.function_call),
}));
return new GPTTokens({
model,
messages: reformattedMessages as unknown as GPTTokensMessageItem[],
}).usedTokens;
};
export const countLlamaChatTokens = (messages: ChatCompletion.Choice.Message[]) => {
const stringToTokenize = messages
.map((message) => message.content || JSON.stringify(message.function_call))
.join("\n");
const tokens = llamaTokenizer.encode(stringToTokenize);
return tokens.length;
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
.usedTokens;
};
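A quick sketch of calling the counter above; the module path is an assumption, and the model string is assumed to be a valid `SupportedModel`:

```typescript
// Sketch: rough token count for a short conversation (the printed value is illustrative).
import { type ChatCompletion } from "openai/resources/chat";
import { countOpenAIChatTokens } from "./countTokens"; // assumed path

const messages = [
  { role: "system", content: "You are a helpful assistant" },
  { role: "user", content: "What is the capital of Sweden?" },
] as ChatCompletion.Choice.Message[];

console.log(countOpenAIChatTokens("gpt-3.5-turbo-0613", messages)); // e.g. ~25 tokens
```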

View File

@@ -148,49 +148,6 @@ export const useScenarioVars = () => {
);
};
export const useDatasets = () => {
const selectedProjectId = useAppStore((state) => state.selectedProjectId);
return api.datasets.list.useQuery(
{ projectId: selectedProjectId ?? "" },
{ enabled: !!selectedProjectId },
);
};
export const useDataset = () => {
const router = useRouter();
const dataset = api.datasets.get.useQuery(
{ id: router.query.id as string },
{ enabled: !!router.query.id },
);
return dataset;
};
export const useDatasetEntries = () => {
const dataset = useDataset().data;
const { page, pageSize } = usePageParams();
const { data, isLoading, ...rest } = api.datasetEntries.list.useQuery(
{ datasetId: dataset?.id ?? "", page, pageSize },
{ enabled: !!dataset?.id },
);
const [stableData, setStableData] = useState(data);
useEffect(() => {
// Prevent annoying flashes while dataset entries are loading from the server
if (!isLoading) {
setStableData(data);
}
}, [data, isLoading]);
return { data: stableData, isLoading, ...rest };
};
export const useDatasetEntry = (entryId: string | null) => {
return api.datasetEntries.get.useQuery({ id: entryId as string }, { enabled: !!entryId });
};
export const useLoggedCalls = (applyFilters = true) => {
const selectedProjectId = useAppStore((state) => state.selectedProjectId);
const { page, pageSize } = usePageParams();
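For context, a sketch of how a page component might have consumed these hooks before their removal; the component and rendered copy are hypothetical:

```tsx
// Hypothetical dataset page (sketch). `dataset.name` is an assumption about
// the query's return shape, included only for illustration.
const DatasetPage = () => {
  const dataset = useDataset().data;
  const entries = useDatasetEntries();

  if (!dataset) return <div>Loading…</div>;
  return (
    <div>
      <h1>{dataset.name}</h1>
      <p>{entries.isLoading ? "Refreshing entries…" : "Entries up to date"}</p>
    </div>
  );
};
```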

View File

@@ -10,60 +10,3 @@ export const lookupModel = (provider: string, model: string) => {
export const modelLabel = (provider: string, model: string) =>
`${provider}/${lookupModel(provider, model)?.name ?? model}`;
// Check whether a string can be parsed as a message function call
export const parseableToFunctionCall = (str: string) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let parsedJSON: any;
try {
parsedJSON = JSON.parse(str);
} catch {
return false;
}
// Check if the parsedJSON is an object and not null
if (typeof parsedJSON !== "object" || parsedJSON === null) {
return false;
}
// Check if only the keys "name" and "arguments" exist
const keys = Object.keys(parsedJSON as Record<string, unknown>);
if (keys.length !== 2 || !keys.includes("name") || !keys.includes("arguments")) {
return false;
}
// Check if both "name" and "arguments" are of type string
if (typeof parsedJSON.name !== "string" || typeof parsedJSON.arguments !== "string") {
return false;
}
// Check if the "arguments" value is parseable to an object
let parsedArguments: unknown;
try {
parsedArguments = JSON.parse(parsedJSON["arguments"]);
} catch {
return false;
}
// Check if parsedArguments is an object and not null
if (typeof parsedArguments !== "object" || parsedArguments === null) {
return false;
}
return true;
};
export const formatFileSize = (bytes: number, decimals = 2) => {
if (bytes === 0) return "0 Bytes";
const k = 1024;
const dm = decimals < 0 ? 0 : decimals;
const sizes = ["Bytes", "KB", "MB", "GB", "TB"];
for (const size of sizes) {
if (bytes < k) return `${parseFloat(bytes.toFixed(dm))} ${size}`;
bytes /= k;
}
return "> 1024 TB";
};
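A few illustrative calls; the return values follow directly from the code above:

```typescript
// Sketch: expected results for the helpers above.
parseableToFunctionCall('{"name":"f","arguments":"{\\"x\\":1}"}'); // true
parseableToFunctionCall('{"name":"f","arguments":"not json"}'); // false: arguments isn't JSON
parseableToFunctionCall('{"name":"f"}'); // false: "arguments" key missing

formatFileSize(0); // "0 Bytes"
formatFileSize(1536); // "1.5 KB"
formatFileSize(1048576); // "1 MB"
```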

View File

@@ -19,9 +19,7 @@
"baseUrl": ".",
"paths": {
"~/*": ["./src/*"]
},
"typeRoots": ["./types", "./node_modules/@types"],
"types": ["llama-tokenizer-js", "node"]
}
},
"include": [
".eslintrc.cjs",

View File

@@ -1,4 +0,0 @@
declare module "llama-tokenizer-js" {
export function encode(input: string): number[];
export function decode(input: number[]): string;
}

View File

@@ -1,28 +1,27 @@
#!/usr/bin/env bash
# Adapted from https://github.com/openai/openai-node/blob/master/build
set -exuo pipefail
rm -rf dist
rm -rf dist /tmp/openpipe-build-dist
npx tsup
mkdir /tmp/openpipe-build-dist
# copy the package.json file to /dist
cp package.json dist
# copy the README.md file to /dist
cp README.md dist
cp -rp * /tmp/openpipe-build-dist
# Rename package name in package.json
python3 -c "
import json
# Load the package.json file
with open('dist/package.json', 'r') as file:
data = json.load(file)
# Change the package name
with open('/tmp/openpipe-build-dist/package.json', 'r') as f:
data = json.load(f)
data['name'] = 'openpipe'
# Write the changes back to the package.json file
with open('dist/package.json', 'w') as file:
json.dump(data, file, indent=2)
with open('/tmp/openpipe-build-dist/package.json', 'w') as f:
json.dump(data, f, indent=4)
"
rm -rf /tmp/openpipe-build-dist/node_modules
mv /tmp/openpipe-build-dist dist
# build to .js files
(cd dist && npm exec tsc -- --noEmit false)

View File

@@ -1,12 +1,12 @@
import dotenv from "dotenv";
import { expect, test } from "vitest";
import OpenAI from "../openai";
import OpenAI from ".";
import {
ChatCompletion,
CompletionCreateParams,
CreateChatCompletionRequestMessage,
} from "openai-beta/resources/chat/completions";
import { OPClient } from "../../codegen";
import { OPClient } from "../codegen";
import mergeChunks from "./mergeChunks";
import assert from "assert";

View File

@@ -7,10 +7,10 @@ import {
CompletionCreateParams,
} from "openai-beta/resources/chat/completions";
import { WrappedStream } from "./openai/streaming";
import { WrappedStream } from "./streaming";
import { DefaultService, OPClient } from "../codegen";
import { Stream } from "openai-beta/streaming";
import { OpenPipeArgs, OpenPipeMeta, type OpenPipeConfig, getTags } from "./shared";
import { OpenPipeArgs, OpenPipeMeta, type OpenPipeConfig, getTags } from "../shared";
export type ClientOptions = openai.ClientOptions & { openpipe?: OpenPipeConfig };
export default class OpenAI extends openai.OpenAI {

View File

@@ -1,34 +1,16 @@
{
"name": "openpipe-dev",
"version": "0.4.0-beta.3",
"version": "0.3.5",
"type": "module",
"description": "Metrics and auto-evaluation for LLM calls",
"scripts": {
"build": "./build.sh",
"build-update": "./build.sh && ./update-app.sh",
"test": "vitest"
},
"main": "./src/index.ts",
"main": "./index.ts",
"publishConfig": {
"name": "openpipe",
"access": "public",
"main": "./index.cjs",
"module": "./index.js",
"types": "./index.d.ts",
"exports": {
".": {
"import": "./index.js",
"require": "./index.cjs"
},
"./openai": {
"import": "./openai.js",
"require": "./openai.cjs"
},
"./openai/mergeChunks": {
"import": "./openai/mergeChunks.js",
"require": "./openai/mergeChunks.cjs"
}
}
"main": "./index.js"
},
"keywords": [],
"author": "",
@@ -42,16 +24,10 @@
"openai-legacy": "npm:openai@3.3.0"
},
"devDependencies": {
"@rollup/plugin-json": "^6.0.0",
"@rollup/plugin-node-resolve": "^15.2.1",
"@types/lodash-es": "^4.17.8",
"@types/node": "^20.4.8",
"@types/node-fetch": "^2.6.4",
"dotenv": "^16.3.1",
"rollup": "^3.28.1",
"rollup-plugin-typescript2": "^0.35.0",
"tslib": "^2.6.2",
"tsup": "^7.2.0",
"tsx": "^3.12.7",
"typescript": "^5.0.4",
"vitest": "^0.33.0"

View File

@@ -6,4 +6,4 @@ set -exuo pipefail
./build.sh
(cd dist && pnpm publish --access public --tag beta --no-git-checks)
(cd dist && pnpm publish --access public)

View File

@@ -1,6 +1,6 @@
import pkg from "../package.json";
import pkg from "./package.json";
import { DefaultService } from "../codegen";
import { DefaultService } from "./codegen";
export type OpenPipeConfig = {
apiKey?: string;

View File

@@ -12,6 +12,7 @@
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"incremental": true,
"noUncheckedIndexedAccess": true,
"noEmit": true,
"sourceMap": true,

View File

@@ -1,24 +0,0 @@
import { Options } from "tsup";
const config: Options = {
splitting: false, // Disable code splitting
sourcemap: true, // Include sourcemaps
target: "es2020", // Target ES2020 syntax for modern environments
dts: true, // Generate declaration files
// Define entry points
entry: [
"src/index.ts", // Main entry
"src/openai.ts", // 'openai' sub-module
"src/openai/mergeChunks.ts", // 'openai/mergeChunks' sub-module
"src/openai/streaming.ts", // 'openai/streaming' sub-module
],
// Define format of the output bundles
format: ["cjs", "esm"],
// External libraries that shouldn't be bundled
external: [...Object.keys(require("./package.json").dependencies || {})],
};
export default config;
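This entry list pairs with the `exports` map in the published package.json (shown earlier in this diff). Here's a sketch of the deep imports it enables for consumers; the default export on `mergeChunks` is an assumption based on its entry point:

```typescript
// Sketch: consumer-side imports enabled by the sub-module entries above.
import OpenAI from "openpipe/openai";
import mergeChunks from "openpipe/openai/mergeChunks"; // assumes a default export

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
```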

View File

@@ -1,4 +0,0 @@
(cd dist && yalc publish)
(cd ../../app && yalc add openpipe)
(cd ../../app && pnpm install)

View File

@@ -1,19 +0,0 @@
---
title: "Import Data - Beta"
sidebarTitle: "Import Data"
description: "Import external data to kickstart your fine-tuning process. Use the OpenAI chat fine-tuning format."
---
Upload a JSONL file in which each line contains an array of OpenAI messages as the input and a single assistant message as the output.
<Frame>![](/images/features/importing-data.png)</Frame>
The format is compatible with both standard content messages and function calls. Input arrays must contain at least one message.
```jsonl
...
{"input":[{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"What is the capitol of Sweden?"}],"output":{"role":"assistant","content":null,"function_call":{"name":"log_capitol","arguments":"{\"capitol\":\"Stockholm\"}"}}}
{"input":[{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"What is the capitol of Tasmania?"}],"output":{"role":"assistant","content":null,"function_call":{"name":"log_capitol","arguments":"{\"capitol\":\"Hobart\"}"}}}
...
```
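Each line therefore deserializes to roughly the following shape, sketched here in TypeScript with field names taken from the example above:

```typescript
// Sketch of one JSONL row (mirrors the example lines above).
type ChatMessage = {
  role: "system" | "user" | "assistant" | "function";
  content: string | null;
  function_call?: { name: string; arguments: string }; // arguments is a JSON-encoded string
};

type DatasetEntryRow = {
  input: ChatMessage[]; // one or more messages
  output: ChatMessage; // a single assistant message
};
```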

Binary file not shown.

Before: 408 KiB

View File

@@ -41,7 +41,6 @@
"group": "Features",
"pages": [
"features/log-filters",
"features/importing-data",
"features/exporting-data",
"features/fine-tuning",
"features/experiments"

4
examples/.gitignore vendored
View File

@@ -1,6 +1,4 @@
axolotl/
models/
data/
wandb/
cache/
.ipynb_checkpoints/
wandb/

View File

@@ -0,0 +1,473 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time: 2023-08-24 21:25:06\n",
"Current Time: 2023-08-24 21:25:36\n"
]
}
],
"source": [
"import time\n",
"\n",
"while True:\n",
" current_time = time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime())\n",
" print(f\"Current Time: {current_time}\")\n",
" time.sleep(30)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of recipes: 2,147,248\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"all_recipes = load_dataset(\"corbt/all-recipes\")[\"train\"][\"input\"]\n",
"\n",
"print(f\"Number of recipes: {len(all_recipes):,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-24 19:38:29 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-24 19:39:48 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
}
],
"source": [
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(model=\"./models/run1/merged\", max_num_batched_tokens=4096)\n",
"\n",
"sampling_params = SamplingParams(\n",
" # 120 should be fine for the work we're doing here.\n",
" max_tokens=120,\n",
" # This is a deterministic task so temperature=0 is best.\n",
" temperature=0,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time: 1692906050.3340027\n",
"Processing recipes 0 to 10,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:51<00:00, 34.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 10,000 to 20,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:54<00:00, 33.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 20,000 to 30,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 30,000 to 40,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 10000/10000 [04:53<00:00, 34.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing recipes 40,000 to 50,000...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 48%|████▊ | 4796/10000 [02:21<03:18, 26.22it/s]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mlen\u001b[39m(all_recipes), BATCH_SIZE):\n\u001b[1;32m 11\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mProcessing recipes \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m to \u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m+\u001b[39mBATCH_SIZE\u001b[39m:\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[39m=\u001b[39m llm\u001b[39m.\u001b[39;49mgenerate(all_recipes[i:i\u001b[39m+\u001b[39;49mBATCH_SIZE], sampling_params\u001b[39m=\u001b[39;49msampling_params)\n\u001b[1;32m 14\u001b[0m all_outputs\u001b[39m.\u001b[39mextend([o\u001b[39m.\u001b[39moutputs[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mtext \u001b[39mfor\u001b[39;00m o \u001b[39min\u001b[39;00m outputs])\n\u001b[1;32m 16\u001b[0m end_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:130\u001b[0m, in \u001b[0;36mLLM.generate\u001b[0;34m(self, prompts, sampling_params, prompt_token_ids, use_tqdm)\u001b[0m\n\u001b[1;32m 128\u001b[0m token_ids \u001b[39m=\u001b[39m prompt_token_ids[i]\n\u001b[1;32m 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_add_request(prompt, sampling_params, token_ids)\n\u001b[0;32m--> 130\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_engine(use_tqdm)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py:150\u001b[0m, in \u001b[0;36mLLM._run_engine\u001b[0;34m(self, use_tqdm)\u001b[0m\n\u001b[1;32m 148\u001b[0m outputs: List[RequestOutput] \u001b[39m=\u001b[39m []\n\u001b[1;32m 149\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mllm_engine\u001b[39m.\u001b[39mhas_unfinished_requests():\n\u001b[0;32m--> 150\u001b[0m step_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mllm_engine\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 151\u001b[0m \u001b[39mfor\u001b[39;00m output \u001b[39min\u001b[39;00m step_outputs:\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m output\u001b[39m.\u001b[39mfinished:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:313\u001b[0m, in \u001b[0;36mLLMEngine.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 308\u001b[0m RequestOutput\u001b[39m.\u001b[39mfrom_seq_group(seq_group)\n\u001b[1;32m 309\u001b[0m \u001b[39mfor\u001b[39;00m seq_group \u001b[39min\u001b[39;00m scheduler_outputs\u001b[39m.\u001b[39mignored_seq_groups\n\u001b[1;32m 310\u001b[0m ]\n\u001b[1;32m 312\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_workers(\n\u001b[1;32m 314\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mexecute_model\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 315\u001b[0m seq_group_metadata_list\u001b[39m=\u001b[39;49mseq_group_metadata_list,\n\u001b[1;32m 316\u001b[0m blocks_to_swap_in\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_in,\n\u001b[1;32m 317\u001b[0m blocks_to_swap_out\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_swap_out,\n\u001b[1;32m 318\u001b[0m blocks_to_copy\u001b[39m=\u001b[39;49mscheduler_outputs\u001b[39m.\u001b[39;49mblocks_to_copy,\n\u001b[1;32m 319\u001b[0m )\n\u001b[1;32m 320\u001b[0m \u001b[39m# Update the scheduler with the model outputs.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m seq_groups \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscheduler\u001b[39m.\u001b[39mupdate(output)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py:470\u001b[0m, in \u001b[0;36mLLMEngine._run_workers\u001b[0;34m(self, method, get_all_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 468\u001b[0m executor \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(worker, method)\n\u001b[0;32m--> 470\u001b[0m output \u001b[39m=\u001b[39m executor(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 471\u001b[0m all_outputs\u001b[39m.\u001b[39mappend(output)\n\u001b[1;32m 473\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparallel_config\u001b[39m.\u001b[39mworker_use_ray:\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py:293\u001b[0m, in \u001b[0;36mWorker.execute_model\u001b[0;34m(self, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)\u001b[0m\n\u001b[1;32m 289\u001b[0m input_tokens, input_positions, input_metadata \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_prepare_inputs(\n\u001b[1;32m 290\u001b[0m seq_group_metadata_list)\n\u001b[1;32m 292\u001b[0m \u001b[39m# Execute the model.\u001b[39;00m\n\u001b[0;32m--> 293\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m 294\u001b[0m input_ids\u001b[39m=\u001b[39;49minput_tokens,\n\u001b[1;32m 295\u001b[0m positions\u001b[39m=\u001b[39;49minput_positions,\n\u001b[1;32m 296\u001b[0m kv_caches\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgpu_cache,\n\u001b[1;32m 297\u001b[0m input_metadata\u001b[39m=\u001b[39;49minput_metadata,\n\u001b[1;32m 298\u001b[0m cache_events\u001b[39m=\u001b[39;49mcache_events,\n\u001b[1;32m 299\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py:255\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, positions, kv_caches, input_metadata, cache_events)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 246\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 247\u001b[0m input_ids: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 251\u001b[0m cache_events: Optional[List[torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mEvent]],\n\u001b[1;32m 252\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 253\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(input_ids, positions, kv_caches,\n\u001b[1;32m 254\u001b[0m input_metadata, cache_events)\n\u001b[0;32m--> 255\u001b[0m next_tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msampler(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlm_head\u001b[39m.\u001b[39;49mweight, hidden_states,\n\u001b[1;32m 256\u001b[0m input_metadata)\n\u001b[1;32m 257\u001b[0m \u001b[39mreturn\u001b[39;00m next_tokens\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py:44\u001b[0m, in \u001b[0;36mSampler.forward\u001b[0;34m(self, embedding, hidden_states, input_metadata, embedding_bias)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 37\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 38\u001b[0m embedding: torch\u001b[39m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Dict[\u001b[39mint\u001b[39m, SequenceOutputs]:\n\u001b[1;32m 43\u001b[0m \u001b[39m# Get the hidden states that we use for sampling.\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m hidden_states \u001b[39m=\u001b[39m _prune_hidden_states(hidden_states, input_metadata)\n\u001b[1;32m 46\u001b[0m \u001b[39m# Get the logits for the next tokens.\u001b[39;00m\n\u001b[1;32m 47\u001b[0m logits \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmatmul(hidden_states, embedding\u001b[39m.\u001b[39mt())\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# We'll process our recipes in batches of 10,000.\n",
"\n",
"import time\n",
"\n",
"BATCH_SIZE = 10000\n",
"all_outputs = []\n",
"\n",
"start_time = time.time()\n",
"print(f\"Start time: {start_time}\")\n",
"for i in range(0, len(all_recipes), BATCH_SIZE):\n",
" print(f\"Processing recipes {i:,} to {i+BATCH_SIZE:,}...\")\n",
" outputs = llm.generate(\n",
" all_recipes[i : i + BATCH_SIZE], sampling_params=sampling_params\n",
" )\n",
"\n",
" all_outputs.extend([o.outputs[0].text for o in outputs])\n",
"\n",
"end_time = time.time()\n",
"print(f\"End time: {end_time}\")\n",
"print(f\"Total hours: {((end_time - start_time) / 3600):.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Cost to Classify One Recipe</th>\n",
" <th>Cost to Classify Entire Dataset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# I used an on-demand Nvidia L40 on RunPod for this, at an hourly cost of $1.14.\n",
"finetuned_hourly_cost = 1.14\n",
"\n",
"finetuned_total_hours = 16.54\n",
"\n",
"finetuned_avg_cost = finetuned_hourly_cost * finetuned_total_hours / len(all_recipes)\n",
"\n",
"# The average input and output tokens calculated by OpenAI, based on the 5000 recipes I sent them\n",
"avg_input_tokens = 276\n",
"avg_output_tokens = 42\n",
"\n",
"# Token pricing from https://openai.com/pricing\n",
"gpt_4_avg_cost = avg_input_tokens * 0.03 / 1000 + avg_output_tokens * 0.06 / 1000\n",
"\n",
"gpt_35_avg_cost = avg_input_tokens * 0.0015 / 1000 + avg_output_tokens * 0.0016 / 1000\n",
"\n",
"gpt_35_finetuned_avg_cost = (\n",
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"# Multiply the number of recipes\n",
"# gpt_4_cost = len(all_recipes) * gpt_4_avg_cost\n",
"# gpt_35_cost = len(all_recipes) * gpt_35_avg_cost\n",
"# gpt_35_finetuned_cost = len(all_recipes) * gpt_35_finetuned_avg_cost\n",
"\n",
"# Let's put this in a dataframe for easier comparison.\n",
"\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
" finetuned_avg_cost,\n",
" gpt_35_avg_cost,\n",
" gpt_35_finetuned_avg_cost,\n",
" gpt_4_avg_cost,\n",
" ],\n",
" }\n",
")\n",
"\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"\n",
"costs\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and just for fun, let's figure out how many recipes my pescatarian basement-dwelling brother can make! 😂"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,14 +1,10 @@
# Tutorial: Fine-Tune your Own Llama 2
# OpenPipe demo: fine-tuning your own model
Hi there! This directory should give you a brief overview of how to fine-tune a Llama 2 model from start to finish. The example model we're training will classify recipes from [a large dataset scraped from the internet](https://www.kaggle.com/datasets/wilmerarltstrmberg/recipe-dataset-over-2m). We'll use GPT-4 to generate labels for our training and test set, then fine-tune a Llama 2 model using the [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library. You should review the notebooks in this directory in the following order:
Hi there! This repository should give you a brief overview of how to fine-tune a competitive model from start to finish. You should review the notebooks in this directory in the following order:
1. [./1-generate-data.ipynb](./1-generate-data.ipynb): Demonstrates how to generate a sample dataset of GPT-4 completions, store it using OpenPipe, and then export it in a format suitable for training a model.
2. [./2-train.ipynb](./2-train.ipynb): Trains a Llama 2 7B model on the dataset from step (1).
3. [./3-evaluate.ipynb](./3-evaluate.ipynb): Evaluates the model we trained using a special test set that we set aside in step (1).
4. [./4-benchmark.ipynb](./4-benchmark.ipynb): A script to compare costs and completion latencies between our fine-tuned model, GPT-3.5, and GPT-4.
1. [./generate-data.ipynb](./generate-data.ipynb): Demonstrates how to generate a sample dataset of GPT-4 completions, store it using OpenPipe, and then export it in a format suitable for training a model.
2. [./train.ipynb](./train.ipynb): Trains a Llama 2 7B model on the dataset from step (1).
3. [./evaluate.ipynb](./evaluate.ipynb): Evaluates the model we trained using a special test set that we set aside in step (1).
4. [./benchmark.ipynb](./benchmark.ipynb): A script to compare costs and completion latencies between our fine-tuned model, GPT-3.5, and GPT-4.
If you want to follow along yourself, I recommend using [RunPod](https://www.runpod.io/). The training scripts we use will run on any of their GPUs with 24GB of vRAM or more.
## About OpenPipe
[OpenPipe](https://openpipe.ai) is an open-source company that makes it easy for product engineers to build and deploy their own fine-tuned models. OpenPipe actually takes care of a lot of the steps this repository covers for you automatically, but we still wanted to give back and explain how fine-tuning works under the hood.
If you want to follow along yourself, I recommend using [RunPod](https://www.runpod.io/). The training scripts we use will run on any of their GPUs with 24GB of vRAM or more.

View File

@@ -6,22 +6,104 @@
"source": [
"I'm pretty happy with my model's accuracy relative to GPT-4. How does it compare cost-wise?\n",
"\n",
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈.\n"
"I'll really push this to its limits -- let's see how quickly our poor model can classify the [full 2-million-recipe dataset](https://huggingface.co/datasets/corbt/all-recipes) 😈."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets==2.14.4 in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (1.24.4)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.0.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2.28.1)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (4.66.1)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (0.16.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.4) (6.0)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (2.1.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.4) (1.3.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (3.9.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.4) (4.7.1)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.14.4) (2022.12.7)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.14.4) (2023.3)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.4) (1.16.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install datasets==2.14.4 vllm==0.1.3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -194,12 +276,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data.\n"
"Nice! I've processed all 2,147,248 recipes in under 17 hours. Let's do a cost comparison with GPT-3.5 and GPT-4. I'll use the GPT-4 latency/cost numbers based on the 5000 samples used to generate our model's training data."
]
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -231,47 +313,47 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Llama2 7B (FT)</td>\n",
" <td>Llama 2 7B (finetuned)</td>\n",
" <td>0.000009</td>\n",
" <td>18.81</td>\n",
" <td>18.86</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3.5</td>\n",
" <td>0.000481</td>\n",
" <td>1033.26</td>\n",
" <td>1,033.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3.5 (FT)</td>\n",
" <td>GPT-3.5 (finetuned)</td>\n",
" <td>0.004044</td>\n",
" <td>8683.47</td>\n",
" <td>8,683.47</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-4</td>\n",
" <td>0.010800</td>\n",
" <td>23190.28</td>\n",
" <td>23,190.28</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Cost to Classify One Recipe \\\n",
"0 Llama2 7B (FT) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (FT) 0.004044 \n",
"3 GPT-4 0.010800 \n",
" Model Cost to Classify One Recipe \\\n",
"0 Llama 2 7B (finetuned) 0.000009 \n",
"1 GPT-3.5 0.000481 \n",
"2 GPT-3.5 (finetuned) 0.004044 \n",
"3 GPT-4 0.010800 \n",
"\n",
" Cost to Classify Entire Dataset \n",
"0 18.81 \n",
"1 1033.26 \n",
"2 8683.47 \n",
"3 23190.28 "
" Cost to Classify Entire Dataset \n",
"0 18.86 \n",
"1 1,033.26 \n",
"2 8,683.47 \n",
"3 23,190.28 "
]
},
"execution_count": 23,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -300,12 +382,12 @@
" avg_input_tokens * 0.012 / 1000 + avg_output_tokens * 0.016 / 1000 + 0.06 / 1000\n",
")\n",
"\n",
"models = pd.DataFrame(\n",
"costs = pd.DataFrame(\n",
" {\n",
" \"Model\": [\n",
" \"Llama2 7B (FT)\",\n",
" \"Llama 2 7B (finetuned)\",\n",
" \"GPT-3.5\",\n",
" \"GPT-3.5 (FT)\",\n",
" \"GPT-3.5 (finetuned)\",\n",
" \"GPT-4\",\n",
" ],\n",
" \"Cost to Classify One Recipe\": [\n",
@@ -317,11 +399,12 @@
" }\n",
")\n",
"\n",
"models[\"Cost to Classify Entire Dataset\"] = (\n",
" models[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").round(2)\n",
"costs[\"Cost to Classify Entire Dataset\"] = (\n",
" costs[\"Cost to Classify One Recipe\"] * len(all_recipes)\n",
").map(lambda x: f\"{x:,.2f}\")\n",
"\n",
"models\n"
"\n",
"costs\n"
]
}
],

View File

@@ -4,29 +4,96 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first.\n"
"I have a model in `./models/run1/merged` that was trained on GPT-4's outputs to classify recipes. I need to figure out whether it does a good job at classifying recipes. I'll install dependencies first."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: vllm==0.1.3 in /usr/local/lib/python3.10/dist-packages (0.1.3)\n",
"Requirement already satisfied: pandas==2.0.3 in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (5.9.5)\n",
"Requirement already satisfied: ray>=2.5.1 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.6.3)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.1.99)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.24.4)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (2.0.1+cu118)\n",
"Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (4.33.0.dev0)\n",
"Requirement already satisfied: xformers>=0.0.19 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.0.21)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.101.1)\n",
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (0.23.2)\n",
"Requirement already satisfied: pydantic<2 in /usr/local/lib/python3.10/dist-packages (from vllm==0.1.3) (1.10.12)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.0.3) (2023.3)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2->vllm==0.1.3) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.16.0)\n",
"Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (8.1.7)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (3.9.0)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.18.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.0.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (23.1)\n",
"Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (4.24.1)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (6.0)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.3.1)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (2.28.1)\n",
"Requirement already satisfied: grpcio>=1.42.0 in /usr/local/lib/python3.10/dist-packages (from ray>=2.5.1->vllm==0.1.3) (1.57.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.0.0->vllm==0.1.3) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (3.25.0)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=2.0.0->vllm==0.1.3) (15.0.7)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.16.4)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (0.3.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->vllm==0.1.3) (4.66.1)\n",
"Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->vllm==0.1.3) (0.27.0)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn->vllm==0.1.3) (0.14.0)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers>=4.31.0->vllm==0.1.3) (2023.6.0)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.10/dist-packages (from starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (3.7.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.0.0->vllm==0.1.3) (2.1.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (23.1.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (2023.6.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.29.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->ray>=2.5.1->vllm==0.1.3) (0.8.10)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->ray>=2.5.1->vllm==0.1.3) (2022.12.7)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.0.0->vllm==0.1.3) (1.2.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->vllm==0.1.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install vllm==0.1.3 pandas==2.0.3 joblib==1.3.2"
"%pip install vllm==0.1.3 pandas==2.0.3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance.\n"
"Remember I got a \"test.jsonl\" file from OpenPipe back in [./prepare.ipynb](./prepare.ipynb)? That's data from our dataset that we didn't use in training, so we can use it to check our model's performance."
]
},
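For orientation, the elided cell below presumably just loads that export. A minimal sketch, assuming the file was moved to `data/test.jsonl` as described in prepare.ipynb:

```python
import pandas as pd

# Each line of the export is one logged call: `instruction` holds the
# JSON-encoded chat messages and `output` the GPT-4 response used as the label.
test_data = pd.read_json("data/test.jsonl", lines=True)
print(f"{len(test_data)} test examples")
```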
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -39,12 +106,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results.\n"
"During the training process Axolotl transformed our data into an instruction/response format known as the \"Alpaca format\" based on [the project that introduced it](https://github.com/tatsu-lab/stanford_alpaca). I need to transform my test data into the same format for best results."
]
},
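The reformatting code itself is elided from this hunk, so here is a hedged sketch of the idea: map each logged [system, user] message pair onto the Alpaca instruction/input/response scaffold. The exact template string Axolotl used is an assumption; only the scaffold is standard:

```python
import json
import pandas as pd

# Assumed Alpaca-style template; Axolotl's actual prompt strings may differ.
TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)

def to_alpaca(instruction_json: str) -> str:
    """Turn a logged [system, user] message pair into one Alpaca-format prompt."""
    system_msg, user_msg = json.loads(instruction_json)
    return TEMPLATE.format(instruction=system_msg["content"], input=user_msg["content"])

test_data = pd.read_json("data/test.jsonl", lines=True)
test_data["prompt"] = test_data["instruction"].apply(to_alpaca)
```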
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -80,27 +147,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model.\n"
"Next up, I'll use [vLLM](https://vllm.readthedocs.io/en/latest/) to efficiently process all the prompts in our test data with our own model."
]
},
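The generation cell that follows uses vLLM's offline batch-inference pattern, which looks roughly like this (a sketch against the vllm 0.1.3 API; the `prompt` column and the sampling settings are assumptions):

```python
from vllm import LLM, SamplingParams

# Point vLLM at the merged fine-tuned checkpoint; generate() batches internally.
llm = LLM(model="./models/run1/merged")
params = SamplingParams(temperature=0, max_tokens=256)

results = llm.generate(test_data["prompt"].tolist(), params)
test_data["my_outputs"] = [r.outputs[0].text for r in results]
```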
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 08-28 00:26:23 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-28 00:27:26 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
"INFO 08-25 03:58:49 llm_engine.py:70] Initializing an LLM engine with config: model='./models/run1/merged', tokenizer='./models/run1/merged', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)\n",
"INFO 08-25 03:59:40 llm_engine.py:196] # GPU blocks: 3419, # CPU blocks: 512\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.34it/s]"
"Processed prompts: 100%|██████████| 500/500 [00:37<00:00, 13.42it/s]"
]
},
{
@@ -144,12 +211,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:\n"
"Ok, we have our outputs! There are 5 categories we classify each recipe on, so let's check what percentage of the time our model's output matches GPT-4's. I'll write a quick eval function for that:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -172,29 +239,20 @@
" return args_dict\n",
"\n",
"\n",
"test_data[\"output_parsed\"] = test_data[\"output\"].apply(parse_fn_call)\n",
"test_data[\"my_outputs_parsed\"] = test_data[\"my_outputs\"].apply(parse_fn_call)\n",
"\n",
"\n",
"def calculate_accuracy(row, labels_col):\n",
"def calculate_accuracy(row):\n",
" \"\"\"Calculate the fraction of my model's outputs that match the reference outputs\"\"\"\n",
" true_outputs = row[\"output_parsed\"]\n",
" labels_outputs = row[labels_col]\n",
"\n",
" # print(f\"true_outputs: {true_outputs}\")\n",
" # print(f\"my_outputs: {row[labels_col]}\")\n",
" true_outputs = parse_fn_call(row[\"output\"])\n",
" my_outputs = parse_fn_call(row[\"my_outputs\"])\n",
"\n",
" num_matching_outputs = 0\n",
" for key in true_outputs.keys():\n",
" if key in labels_outputs and true_outputs[key] == labels_outputs[key]:\n",
" if key in my_outputs and true_outputs[key] == my_outputs[key]:\n",
" num_matching_outputs += 1\n",
"\n",
" return num_matching_outputs / len(true_outputs)\n",
"\n",
"\n",
"test_data[\"accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"my_outputs_parsed\"\n",
")\n",
"test_data[\"accuracy\"] = test_data.apply(calculate_accuracy, axis=1)\n",
"\n",
"print(f\"Overall accuracy: {test_data['accuracy'].mean():.2f}\")\n"
]
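`parse_fn_call` is only partially visible in this hunk. Its elided body presumably does something like the following reconstruction (an assumption based on the `function_call.arguments` structure used elsewhere in these notebooks, not the notebook's exact code):

```python
import json

def parse_fn_call(raw: str) -> dict:
    """Parse a logged chat message into its function-call arguments dict."""
    message = json.loads(raw)
    args_dict = json.loads(message["function_call"]["arguments"])
    return args_dict
```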
@@ -203,293 +261,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"95% seems good! However, we don't have much to compare it to. Let's see how GPT-3.5 would do on the same task as a baseline. We'll use the same prompt we used with GPT-4 to generate the labels.\n"
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample recipe:\n",
"--------------\n",
"Pan Gravy\n",
"\n",
"Ingredients:\n",
"- 1/3 cup all purpose flour\n",
"- 1/3 cup turkey drippings\n",
"- 3 cup water or broth\n",
"- 1/8 to 1/4 teaspoon salt\n",
"- 1/8 tsp pepper\n",
"\n",
"Directions:\n",
"- In a skillet or roasting pan, add flour to drippings; blend well.\n",
"- Cook over medium heat 2 to 3 minutes until smooth and light brown, stirring constantly.\n",
"- Add water; cook until mixture boils and thickens, stirring constantly.\n",
"- Stir in salt and pepper.\n",
"- *Flour and drippings can be decreased to 1/4 cup each for thinner gravy.\n",
"- *\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def extract_recipe(row):\n",
" \"\"\"Extract the recipe from the instruction\"\"\"\n",
" return json.loads(row[\"instruction\"])[1][\"content\"]\n",
"\n",
"\n",
"recipes = test_data.apply(extract_recipe, axis=1)\n",
"print(f\"Sample recipe:\\n--------------\\n{recipes[0]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classifying first recipe:\n",
"------------------\n",
"{'has_non_fish_meat': False, 'requires_oven': False, 'requires_stove': True, 'cook_time_over_30_mins': False, 'main_dish': False}\n"
]
}
],
"source": [
"import joblib\n",
"import openai\n",
"import os\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"memory = joblib.Memory(\"./cache\", verbose=0)\n",
"\n",
"\n",
"@memory.cache\n",
"def classify_recipe_35(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": recipe,\n",
" },\n",
" ],\n",
" functions=[\n",
" {\n",
" \"name\": \"classify\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"has_non_fish_meat\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe contains any meat or meat products (eg. chicken broth) besides fish\",\n",
" },\n",
" \"requires_oven\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires an oven\",\n",
" },\n",
" \"requires_stove\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe requires a stove\",\n",
" },\n",
" \"cook_time_over_30_mins\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe takes over 30 minutes to prepare and cook, including waiting time\",\n",
" },\n",
" \"main_dish\": {\n",
" \"type\": \"boolean\",\n",
" \"description\": \"True if the recipe can be served as a main dish\",\n",
" },\n",
" },\n",
" \"required\": [\n",
" \"has_non_fish_meat\",\n",
" \"requires_oven\",\n",
" \"requires_stove\",\n",
" \"cook_time_over_30_mins\",\n",
" \"main_dish\",\n",
" ],\n",
" },\n",
" }\n",
" ],\n",
" function_call={\n",
" \"name\": \"classify\",\n",
" },\n",
" )\n",
" return json.loads(completion.choices[0].message.function_call.arguments)\n",
"\n",
"\n",
"print(\"Classifying first recipe:\\n------------------\")\n",
"print(classify_recipe_35(recipes[0]))\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"\u001b[A\n",
"100%|██████████| 500/500 [00:31<00:00, 15.77it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"test_data[\"gpt_3.5\"] = [classify_recipe_35(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 accuracy: 0.91\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 accuracy: {test_data['gpt_3.5_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And for completeness, let's try a fine-tuned GPT-3.5 model. You can find the fine-tuning code in [finetune-gpt-3.5.ipynb](./finetune-gpt-3.5.ipynb)\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'has_non_fish_meat': True,\n",
" 'requires_oven': False,\n",
" 'requires_stove': True,\n",
" 'cook_time_over_30_mins': False,\n",
" 'main_dish': False}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@memory.cache\n",
"def classify_recipe_35_ft(recipe: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=\"ft:gpt-3.5-turbo-0613:openpipe::7rZpPqYn\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"Your goal is to classify a recipe along several \"\n",
" \"dimensions.Pay attention to the instructions.\",\n",
" },\n",
" {\"role\": \"user\", \"content\": recipe},\n",
" ],\n",
" )\n",
"\n",
" return json.loads(completion.choices[0].message.content)\n",
"\n",
"\n",
"classify_recipe_35_ft(recipes[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 500/500 [07:31<00:00, 1.11it/s]\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft\"] = [classify_recipe_35_ft(r) for r in tqdm(recipes)]\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3.5 FT accuracy: 0.94\n"
]
}
],
"source": [
"test_data[\"gpt_3.5_ft_accuracy\"] = test_data.apply(\n",
" calculate_accuracy, axis=1, labels_col=\"gpt_3.5_ft\"\n",
")\n",
"\n",
"print(f\"GPT-3.5 FT accuracy: {test_data['gpt_3.5_ft_accuracy'].mean():.2f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! However, there are still a few rows where the model outputs don't match. Let's take a closer look.\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -848,28 +625,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A realistic point of comparison here might be GPT-3.5. Let's try to classify the same set of recipes using GPT-3.5 and see how it does. We'll use the same prompt that we used with GPT-4 to generate the initial training data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!\n"
"Looking at the outputs, it's clear that our model still makes some mistakes. But at the same time, there are plenty of examples like \"Rhonda's Butter Chess Pie\" where our model gets it right, even though GPT-4 got it wrong! And there are also cases like the \"Veggie Casserole\", where the \"right\" answer is truly ambiguous and really both answers are defensible.\n",
"\n",
"Interested in cost/latency benchmarking? You can check out [./benchmarking.ipynb](./benchmarking.ipynb) for an overview of my findings!"
]
},
{

View File

@@ -1,207 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample training data:\n",
"{'messages': [{'content': 'Your goal is to classify a recipe along several '\n",
" 'dimensions.Pay attention to the instructions.',\n",
" 'role': 'system'},\n",
" {'content': 'Homemade Salad Dressing\\n'\n",
" '\\n'\n",
" 'Ingredients:\\n'\n",
" \"- 1 pt. Hellmann's mayonnaise\\n\"\n",
" '- 1 pt. buttermilk\\n'\n",
" '- 1 tsp. Accent\\n'\n",
" '- 2 Tbsp. dry parsley\\n'\n",
" '- 2 pkg. low-calorie Italian salad dressing mix\\n'\n",
" '- 1 can jalapeno peppers or 4 oz. Jimenez green '\n",
" 'sauce\\n'\n",
" '\\n'\n",
" 'Directions:\\n'\n",
" '- Blend well in blender; store in refrigerator.\\n'\n",
" '- For dip, decrease liquid.',\n",
" 'role': 'user'},\n",
" {'content': '{\\n'\n",
" '\"has_non_fish_meat\": false,\\n'\n",
" '\"requires_oven\": false,\\n'\n",
" '\"requires_stove\": false,\\n'\n",
" '\"cook_time_over_30_mins\": false,\\n'\n",
" '\"main_dish\": false\\n'\n",
" '}',\n",
" 'role': 'assistant'}]}\n"
]
}
],
"source": [
"import pandas as pd\n",
"from pprint import pprint\n",
"import json\n",
"\n",
"df = pd.read_json(\"data/train.jsonl\", lines=True)\n",
"\n",
"training_data = []\n",
"for row in df.itertuples():\n",
" input = json.loads(row.instruction)\n",
" output = json.loads(row.output)\n",
"\n",
" output[\"content\"] = output[\"function_call\"][\"arguments\"]\n",
" del output[\"function_call\"]\n",
"\n",
" sample = {\"messages\": input.copy() + [output]}\n",
" training_data.append(sample)\n",
"\n",
"# save the training data to data/train-gpt3.5.jsonl\n",
"\n",
"with open(\"data/train-gpt3.5.jsonl\", \"w\") as f:\n",
" for sample in training_data:\n",
" f.write(json.dumps(sample) + \"\\n\")\n",
"\n",
"print(f\"Sample training data:\")\n",
"pprint(training_data[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<File file id=file-faAdQ1KPxZH79ThW4Dbu4z1y at 0x7fa55db5c6d0> JSON: {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"uploaded\",\n",
" \"status_details\": null\n",
"}"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import openai\n",
"\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()\n",
"\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"openai.File.create(\n",
" file=open(\"data/train-gpt3.5.jsonl\", \"rb\"),\n",
" purpose=\"fine-tune\",\n",
" user_provided_filename=\"recipe-classification\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject list at 0x7fa55dbf6930> JSON: {\n",
" \"object\": \"list\",\n",
" \"data\": [\n",
" {\n",
" \"object\": \"file\",\n",
" \"id\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"purpose\": \"fine-tune\",\n",
" \"filename\": \"recipe-classification\",\n",
" \"bytes\": 4210831,\n",
" \"created_at\": 1693000959,\n",
" \"status\": \"processed\",\n",
" \"status_details\": null\n",
" }\n",
" ]\n",
"}"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.File.list()\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<FineTuningJob fine_tuning.job id=ftjob-EjjLxmj9P8apwPRk5s2NPSeB at 0x7fa55ddc4360> JSON: {\n",
" \"object\": \"fine_tuning.job\",\n",
" \"id\": \"ftjob-EjjLxmj9P8apwPRk5s2NPSeB\",\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"created_at\": 1693001190,\n",
" \"finished_at\": null,\n",
" \"fine_tuned_model\": null,\n",
" \"organization_id\": \"org-jRz4nVPMoeGHWL5nVR3Mb0kp\",\n",
" \"result_files\": [],\n",
" \"status\": \"created\",\n",
" \"validation_file\": null,\n",
" \"training_file\": \"file-faAdQ1KPxZH79ThW4Dbu4z1y\",\n",
" \"hyperparameters\": {\n",
" \"n_epochs\": 3\n",
" },\n",
" \"trained_tokens\": null\n",
"}"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"openai.FineTuningJob.create(\n",
" training_file=\"file-faAdQ1KPxZH79ThW4Dbu4z1y\", model=\"gpt-3.5-turbo\"\n",
")\n"
]
}
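The job above is created with `fine_tuned_model` still null; the id used in the eval notebook (`ft:gpt-3.5-turbo-0613:openpipe::7rZpPqYn`) only appears once training finishes. A minimal polling sketch with the same openai 0.27.x client (the sleep interval is arbitrary):

```python
import time
import openai  # API key configured as in the cells above

job_id = "ftjob-EjjLxmj9P8apwPRk5s2NPSeB"
while True:
    job = openai.FineTuningJob.retrieve(job_id)
    if job["status"] in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(60)  # fine-tuning jobs take a while; poll once a minute
print(job["fine_tuned_model"])  # e.g. "ft:gpt-3.5-turbo-0613:openpipe::..."
```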
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -6,24 +6,61 @@
"source": [
"In this notebook I'm using the OpenPipe client to capture a set of calls to the OpenAI API.\n",
"\n",
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁\n"
"For this example I'll blithely throw engineering best practices to the wind and use the notebook itself to manage dependencies. 😁"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: openpipe==3.0.3 in /usr/local/lib/python3.10/dist-packages (3.0.3)\n",
"Requirement already satisfied: python-dotenv==1.0.0 in /usr/local/lib/python3.10/dist-packages (1.0.0)\n",
"Requirement already satisfied: joblib==1.3.2 in /usr/local/lib/python3.10/dist-packages (1.3.2)\n",
"Requirement already satisfied: attrs<24.0.0,>=23.1.0 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (23.1.0)\n",
"Requirement already satisfied: httpx<0.25.0,>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.24.1)\n",
"Requirement already satisfied: openai<0.28.0,>=0.27.8 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.27.9)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (2.8.2)\n",
"Requirement already satisfied: toml<0.11.0,>=0.10.2 in /usr/local/lib/python3.10/dist-packages (from openpipe==3.0.3) (0.10.2)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (2022.12.7)\n",
"Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.17.3)\n",
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.4)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.3.0)\n",
"Requirement already satisfied: requests>=2.20 in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.28.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.66.1)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from openai<0.28.0,>=0.27.8->openpipe==3.0.3) (3.8.5)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil<3.0.0,>=2.8.2->openpipe==3.0.3) (1.16.0)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (0.14.0)\n",
"Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (3.7.1)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (2.1.1)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.20->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.26.13)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->openai<0.28.0,>=0.27.8->openpipe==3.0.3) (1.3.1)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<0.25.0,>=0.24.1->openpipe==3.0.3) (1.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.10 -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%%capture\n",
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 datasets==2.14.4"
"%pip install openpipe==3.0.3 python-dotenv==1.0.0 joblib==1.3.2 datasets==2.14.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe.\n"
"When working with remote datasets (or any data, really), it's a good idea to visually inspect some samples to make sure it looks like you expect. I'll print a recipe."
]
},
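The inspection cell is elided below; the pattern with the `datasets` library installed above is roughly this (the dataset handle and field name are placeholders, not the ones the notebook uses):

```python
from datasets import load_dataset

# Placeholder handle/field -- substitute the actual recipes dataset.
recipes = load_dataset("some-org/all-recipes", split="train")
print(recipes[0]["recipe"])
```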
{
@@ -95,16 +132,15 @@
"Mm, delicious. Anyway, we need to generate a training dataset. We'll call GPT-4 on each of our examples.\n",
"\n",
"In this case, I'll ask GPT-4 to classify each recipe along 5 dimensions:\n",
"\n",
"- has_non_fish_meat\n",
"- requires_oven\n",
"- requires_stove\n",
"- cook_time_over_30_mins\n",
"- main_dish\n",
" - has_non_fish_meat\n",
" - requires_oven\n",
" - requires_stove\n",
" - cook_time_over_30_mins\n",
" - main_dish\n",
"\n",
"That looks like a pretty random list, but there's actually an important unifying thread: I'm looking for meals that my pescatarian brother/co-founder can make in his kitchen-less, near-window-less basement apartment in San Francisco! (If you haven't tried to get an apartment in SF you probably think I'm joking 😂.)\n",
"\n",
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example).\n"
"I'll use [OpenPipe](https://github.com/openpipe/openpipe) to track the API calls and form a training dataset. To follow along you'll need to create a free OpenPipe account, then copy your API key from https://app.openpipe.ai/project/settings into a file called `.env`. You can see an example in [./.env.example](./.env.example)."
]
},
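The capture pattern that cell describes looks roughly like this with the openpipe 3.0.3 client installed above (a sketch: the drop-in `openai` import matches that client's documented usage at the time, while the tag name is invented for illustration):

```python
import os
import dotenv
from openpipe import openai  # drop-in wrapper that logs each call to OpenPipe

dotenv.load_dotenv()  # expects OPENAI_API_KEY and OPENPIPE_API_KEY in .env
openai.api_key = os.getenv("OPENAI_API_KEY")

recipe = "Pan Gravy\n\nIngredients:\n- 1/3 cup all purpose flour\n..."  # sample input
completion = openai.ChatCompletion.create(
    model="gpt-4",
    messages=[
        {"role": "system", "content": "Your goal is to classify a recipe along several dimensions."},
        {"role": "user", "content": recipe},
    ],
    openpipe={"tags": {"prompt_id": "recipe-classification"}},  # tag is illustrative
)
```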
{
@@ -204,7 +240,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!\n"
"That's working, so I'll go ahead and classify all 5000 recipes with GPT-4. Using GPT-4 for this is slowwww and costs about $40. The model I'm fine-tuning will be much faster -- we'll see if we can make it as good!"
]
},
{
@@ -284,11 +320,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, now that my recipes are classified I'll download the training data.\n",
"Ok, now that my recipes are classified I'll download the training data. \n",
"\n",
"Next up I'll train the model -- check out [./train.ipynb](./train.ipynb) for details! Just go to https://app.openpipe.ai/request-logs, select all the logs you created, and click \"Export\". The default 10% testing split is fine for this dataset size.\n",
"\n",
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`.\n"
"I got two files from that: `train.jsonl` and `test.jsonl`. I moved both of them into this repository under `./data/`."
]
}
],

Some files were not shown because too many files have changed in this diff