From f59150ff5b96dfc1fed5d99a8e34bed58b6c6628 Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Wed, 23 Aug 2023 16:13:21 -0700 Subject: [PATCH] Add flow for fine-tuning (#183) * Remove unnecessary dataset code * Fix jump on row selection * Add FineTuneButton * Add model slug to modal * Add fine tunes to schema * Remove dataset routers * Remove more dataset-specific code * Remove more data code * Fix horizontal scroll bar jumping * Add fine tunes page * Actually create the fine tune entry * Add beta modal * Require beta for fine tunes and request logs * Send user to waitlist link * control beta features in .env variable * Combine migration files * Show beta features in app shell * Clear selected log ids last when closing fine tune modal * Remove ModalCloseButton from BetaModal * Remove unused import * Change timestamps to camelCase --- app/@types/nextjs-routes.d.ts | 3 +- app/Dockerfile | 2 +- app/package.json | 1 + .../migration.sql | 48 ++++++ app/prisma/schema.prisma | 43 ++++- app/src/components/InputDropdown.tsx | 15 +- app/src/components/datasets/DatasetCard.tsx | 112 ------------ .../datasets/DatasetEntriesPaginator.tsx | 16 -- .../datasets/DatasetEntriesTable.tsx | 31 ---- .../DatasetHeaderButtons.tsx | 26 --- .../GenerateDataModal.tsx | 128 -------------- app/src/components/datasets/TableRow.tsx | 13 -- .../components/fineTunes/FineTunesTable.tsx | 65 +++++++ app/src/components/nav/AppShell.tsx | 44 +++-- app/src/components/nav/BetaModal.tsx | 67 ++++++++ .../components/requestLogs/FineTuneButton.tsx | 161 ++++++++++++++++++ .../requestLogs/LoggedCallsTable.tsx | 4 +- app/src/components/requestLogs/TableRow.tsx | 27 +-- app/src/env.mjs | 6 +- app/src/pages/dashboard/index.tsx | 2 +- app/src/pages/data/[id].tsx | 97 ----------- app/src/pages/data/index.tsx | 49 ------ app/src/pages/fine-tunes/index.tsx | 18 ++ app/src/pages/request-logs/index.tsx | 62 +++---- .../autogenerateDatasetEntries.ts | 113 ------------ app/src/server/api/root.router.ts | 6 +- .../api/routers/datasetEntries.router.ts | 145 ---------------- app/src/server/api/routers/datasets.router.ts | 88 ---------- .../server/api/routers/fineTunes.router.ts | 113 ++++++++++++ app/src/utils/accessControl.ts | 27 --- app/src/utils/hooks.ts | 38 ++--- pnpm-lock.yaml | 7 + 32 files changed, 617 insertions(+), 960 deletions(-) create mode 100644 app/prisma/migrations/20230822234224_add_fine_tunes/migration.sql delete mode 100644 app/src/components/datasets/DatasetCard.tsx delete mode 100644 app/src/components/datasets/DatasetEntriesPaginator.tsx delete mode 100644 app/src/components/datasets/DatasetEntriesTable.tsx delete mode 100644 app/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx delete mode 100644 app/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx delete mode 100644 app/src/components/datasets/TableRow.tsx create mode 100644 app/src/components/fineTunes/FineTunesTable.tsx create mode 100644 app/src/components/nav/BetaModal.tsx create mode 100644 app/src/components/requestLogs/FineTuneButton.tsx delete mode 100644 app/src/pages/data/[id].tsx delete mode 100644 app/src/pages/data/index.tsx create mode 100644 app/src/pages/fine-tunes/index.tsx delete mode 100644 app/src/server/api/autogenerate/autogenerateDatasetEntries.ts delete mode 100644 app/src/server/api/routers/datasetEntries.router.ts delete mode 100644 app/src/server/api/routers/datasets.router.ts create mode 100644 app/src/server/api/routers/fineTunes.router.ts diff --git a/app/@types/nextjs-routes.d.ts b/app/@types/nextjs-routes.d.ts index 3c9bf4d..4e6ca91 100644 --- a/app/@types/nextjs-routes.d.ts +++ b/app/@types/nextjs-routes.d.ts @@ -19,10 +19,9 @@ declare module "nextjs-routes" { | DynamicRoute<"/api/v1/[...trpc]", { "trpc": string[] }> | StaticRoute<"/api/v1/openapi"> | StaticRoute<"/dashboard"> - | DynamicRoute<"/data/[id]", { "id": string }> - | StaticRoute<"/data"> | DynamicRoute<"/experiments/[experimentSlug]", { "experimentSlug": string }> | StaticRoute<"/experiments"> + | StaticRoute<"/fine-tunes"> | StaticRoute<"/"> | DynamicRoute<"/invitations/[invitationToken]", { "invitationToken": string }> | StaticRoute<"/project/settings"> diff --git a/app/Dockerfile b/app/Dockerfile index 9c98aea..0ca2846 100644 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -23,7 +23,7 @@ ARG NEXT_PUBLIC_SOCKET_URL ARG NEXT_PUBLIC_HOST ARG NEXT_PUBLIC_SENTRY_DSN ARG SENTRY_AUTH_TOKEN -ARG NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS +ARG NEXT_PUBLIC_FF_SHOW_BETA_FEATURES WORKDIR /code COPY --from=deps /code/node_modules ./node_modules diff --git a/app/package.json b/app/package.json index c3a8f1f..cfc4c5f 100644 --- a/app/package.json +++ b/app/package.json @@ -60,6 +60,7 @@ "framer-motion": "^10.12.17", "gpt-tokens": "^1.0.10", "graphile-worker": "^0.13.0", + "human-id": "^4.0.0", "immer": "^10.0.2", "isolated-vm": "^4.5.0", "json-schema-to-typescript": "^13.0.2", diff --git a/app/prisma/migrations/20230822234224_add_fine_tunes/migration.sql b/app/prisma/migrations/20230822234224_add_fine_tunes/migration.sql new file mode 100644 index 0000000..a8c7872 --- /dev/null +++ b/app/prisma/migrations/20230822234224_add_fine_tunes/migration.sql @@ -0,0 +1,48 @@ +/* + Warnings: + + - You are about to drop the column `input` on the `DatasetEntry` table. All the data in the column will be lost. + - You are about to drop the column `output` on the `DatasetEntry` table. All the data in the column will be lost. + - Added the required column `loggedCallId` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty. + +*/ +-- AlterTable +ALTER TABLE "DatasetEntry" DROP COLUMN "input", +DROP COLUMN "output", +ADD COLUMN "loggedCallId" UUID NOT NULL; + +-- AddForeignKey +ALTER TABLE "DatasetEntry" ADD CONSTRAINT "DatasetEntry_loggedCallId_fkey" FOREIGN KEY ("loggedCallId") REFERENCES "LoggedCall"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" ALTER COLUMN "cost" SET DATA TYPE DOUBLE PRECISION; + +-- CreateEnum +CREATE TYPE "FineTuneStatus" AS ENUM ('PENDING', 'TRAINING', 'AWAITING_DEPLOYMENT', 'DEPLOYING', 'DEPLOYED', 'ERROR'); + +-- CreateTable +CREATE TABLE "FineTune" ( + "id" UUID NOT NULL, + "slug" TEXT NOT NULL, + "baseModel" TEXT NOT NULL, + "status" "FineTuneStatus" NOT NULL DEFAULT 'PENDING', + "trainingStartedAt" TIMESTAMP(3), + "trainingFinishedAt" TIMESTAMP(3), + "deploymentStartedAt" TIMESTAMP(3), + "deploymentFinishedAt" TIMESTAMP(3), + "datasetId" UUID NOT NULL, + "projectId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "FineTune_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "FineTune_slug_key" ON "FineTune"("slug"); + +-- AddForeignKey +ALTER TABLE "FineTune" ADD CONSTRAINT "FineTune_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "FineTune" ADD CONSTRAINT "FineTune_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE; \ No newline at end of file diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma index dd12489..7f05bdf 100644 --- a/app/prisma/schema.prisma +++ b/app/prisma/schema.prisma @@ -181,6 +181,7 @@ model Dataset { name String datasetEntries DatasetEntry[] + fineTunes FineTune[] projectId String @db.Uuid project Project @relation(fields: [projectId], references: [id], onDelete: Cascade) @@ -192,8 +193,8 @@ model Dataset { model DatasetEntry { id String @id @default(uuid()) @db.Uuid - input String - output String? + loggedCallId String @db.Uuid + loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) datasetId String @db.Uuid dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade) @@ -216,6 +217,7 @@ model Project { experiments Experiment[] datasets Dataset[] loggedCalls LoggedCall[] + fineTunes FineTune[] apiKeys ApiKey[] } @@ -276,8 +278,9 @@ model LoggedCall { projectId String @db.Uuid project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade) - model String? - tags LoggedCallTag[] + model String? + tags LoggedCallTag[] + datasetEntries DatasetEntry[] createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -312,7 +315,7 @@ model LoggedCallModelResponse { outputTokens Int? finishReason String? completionId String? - cost Decimal? @db.Decimal(18, 12) + cost Float? // The LoggedCall that created this LoggedCallModelResponse originalLoggedCallId String @unique @db.Uuid @@ -427,3 +430,33 @@ model VerificationToken { @@unique([identifier, token]) } + +enum FineTuneStatus { + PENDING + TRAINING + AWAITING_DEPLOYMENT + DEPLOYING + DEPLOYED + ERROR +} + +model FineTune { + id String @id @default(uuid()) @db.Uuid + + slug String @unique + baseModel String + status FineTuneStatus @default(PENDING) + trainingStartedAt DateTime? + trainingFinishedAt DateTime? + deploymentStartedAt DateTime? + deploymentFinishedAt DateTime? + + datasetId String @db.Uuid + dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade) + + projectId String @db.Uuid + project Project @relation(fields: [projectId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +} diff --git a/app/src/components/InputDropdown.tsx b/app/src/components/InputDropdown.tsx index 61af9c6..eade73b 100644 --- a/app/src/components/InputDropdown.tsx +++ b/app/src/components/InputDropdown.tsx @@ -11,6 +11,7 @@ import { Button, Text, useDisclosure, + type InputGroupProps, } from "@chakra-ui/react"; import { FiChevronDown } from "react-icons/fi"; @@ -20,15 +21,25 @@ type InputDropdownProps = { options: ReadonlyArray; selectedOption: T; onSelect: (option: T) => void; + inputGroupProps?: InputGroupProps; }; -const InputDropdown = ({ options, selectedOption, onSelect }: InputDropdownProps) => { +const InputDropdown = ({ + options, + selectedOption, + onSelect, + inputGroupProps, +}: InputDropdownProps) => { const popover = useDisclosure(); return ( - + { - return ( - - - - - {dataset.name} - - - - - - Created {formatTimePast(dataset.createdAt)} - - Updated {formatTimePast(dataset.updatedAt)} - - - - ); -}; - -const CountLabel = ({ label, count }: { label: string; count: number }) => { - return ( - - - {label} - - - {count} - - - ); -}; - -export const NewDatasetCard = () => { - const router = useRouter(); - const selectedProjectId = useAppStore((s) => s.selectedProjectId); - const createMutation = api.datasets.create.useMutation(); - const [createDataset, isLoading] = useHandledAsyncCallback(async () => { - const newDataset = await createMutation.mutateAsync({ projectId: selectedProjectId ?? "" }); - await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); - }, [createMutation, router, selectedProjectId]); - - return ( - - - - - New Dataset - - - - ); -}; - -export const DatasetCardSkeleton = () => ( - - - - - - - -); diff --git a/app/src/components/datasets/DatasetEntriesPaginator.tsx b/app/src/components/datasets/DatasetEntriesPaginator.tsx deleted file mode 100644 index b3d2f4e..0000000 --- a/app/src/components/datasets/DatasetEntriesPaginator.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import { type StackProps } from "@chakra-ui/react"; - -import { useDatasetEntries } from "~/utils/hooks"; -import Paginator from "../Paginator"; - -const DatasetEntriesPaginator = (props: StackProps) => { - const { data } = useDatasetEntries(); - - if (!data) return null; - - const { count } = data; - - return ; -}; - -export default DatasetEntriesPaginator; diff --git a/app/src/components/datasets/DatasetEntriesTable.tsx b/app/src/components/datasets/DatasetEntriesTable.tsx deleted file mode 100644 index a90973b..0000000 --- a/app/src/components/datasets/DatasetEntriesTable.tsx +++ /dev/null @@ -1,31 +0,0 @@ -import { type StackProps, VStack, Table, Th, Tr, Thead, Tbody, Text } from "@chakra-ui/react"; -import { useDatasetEntries } from "~/utils/hooks"; -import TableRow from "./TableRow"; -import DatasetEntriesPaginator from "./DatasetEntriesPaginator"; - -const DatasetEntriesTable = (props: StackProps) => { - const { data } = useDatasetEntries(); - - return ( - - - - - - - - - {data?.entries.map((entry) => )} -
InputOutput
- {(!data || data.entries.length) === 0 ? ( - - No entries found - - ) : ( - - )} -
- ); -}; - -export default DatasetEntriesTable; diff --git a/app/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx b/app/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx deleted file mode 100644 index 840b8e0..0000000 --- a/app/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import { Button, HStack, useDisclosure } from "@chakra-ui/react"; -import { BiImport } from "react-icons/bi"; -import { BsStars } from "react-icons/bs"; - -import { GenerateDataModal } from "./GenerateDataModal"; - -export const DatasetHeaderButtons = () => { - const generateModalDisclosure = useDisclosure(); - - return ( - <> - - - - - - - ); -}; diff --git a/app/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx b/app/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx deleted file mode 100644 index a748ac0..0000000 --- a/app/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx +++ /dev/null @@ -1,128 +0,0 @@ -import { - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalHeader, - ModalOverlay, - ModalFooter, - Text, - HStack, - VStack, - Icon, - NumberInput, - NumberInputField, - NumberInputStepper, - NumberIncrementStepper, - NumberDecrementStepper, - Button, -} from "@chakra-ui/react"; -import { BsStars } from "react-icons/bs"; -import { useState } from "react"; -import { useDataset, useHandledAsyncCallback } from "~/utils/hooks"; -import { api } from "~/utils/api"; -import AutoResizeTextArea from "~/components/AutoResizeTextArea"; - -export const GenerateDataModal = ({ - isOpen, - onClose, -}: { - isOpen: boolean; - onClose: () => void; -}) => { - const utils = api.useContext(); - - const datasetId = useDataset().data?.id; - - const [numToGenerate, setNumToGenerate] = useState(20); - const [inputDescription, setInputDescription] = useState( - "Each input should contain an email body. Half of the emails should contain event details, and the other half should not.", - ); - const [outputDescription, setOutputDescription] = useState( - `Each output should contain "true" or "false", where "true" indicates that the email contains event details.`, - ); - - const generateEntriesMutation = api.datasetEntries.autogenerateEntries.useMutation(); - - const [generateEntries, generateEntriesInProgress] = useHandledAsyncCallback(async () => { - if (!inputDescription || !outputDescription || !numToGenerate || !datasetId) return; - await generateEntriesMutation.mutateAsync({ - datasetId, - inputDescription, - outputDescription, - numToGenerate, - }); - await utils.datasetEntries.list.invalidate(); - onClose(); - }, [ - generateEntriesMutation, - onClose, - inputDescription, - outputDescription, - numToGenerate, - datasetId, - ]); - - return ( - - - - - - - Generate Data - - - - - - - Number of Rows: - setNumToGenerate(parseInt(valueString) || 0)} - value={numToGenerate} - w="24" - > - - - - - - - - - Input Description: - setInputDescription(e.target.value)} - placeholder="Each input should contain..." - /> - - - Output Description (optional): - setOutputDescription(e.target.value)} - placeholder="The output should contain..." - /> - - - - - - - - - ); -}; diff --git a/app/src/components/datasets/TableRow.tsx b/app/src/components/datasets/TableRow.tsx deleted file mode 100644 index 08ad2bb..0000000 --- a/app/src/components/datasets/TableRow.tsx +++ /dev/null @@ -1,13 +0,0 @@ -import { Td, Tr } from "@chakra-ui/react"; -import { type DatasetEntry } from "@prisma/client"; - -const TableRow = ({ entry }: { entry: DatasetEntry }) => { - return ( - - {entry.input} - {entry.output} - - ); -}; - -export default TableRow; diff --git a/app/src/components/fineTunes/FineTunesTable.tsx b/app/src/components/fineTunes/FineTunesTable.tsx new file mode 100644 index 0000000..8a3f897 --- /dev/null +++ b/app/src/components/fineTunes/FineTunesTable.tsx @@ -0,0 +1,65 @@ +import { Card, Table, Thead, Tr, Th, Tbody, Td, VStack, Icon, Text } from "@chakra-ui/react"; +import { FaTable } from "react-icons/fa"; +import { type FineTuneStatus } from "@prisma/client"; + +import dayjs from "~/utils/dayjs"; +import { useFineTunes } from "~/utils/hooks"; + +const FineTunesTable = ({}) => { + const { data } = useFineTunes(); + + const fineTunes = data?.fineTunes || []; + + return ( + + {fineTunes.length ? ( + + + + + + + + + + + + {fineTunes.map((fineTune) => { + return ( + + + + + + + + ); + })} + +
IDCreated AtBase ModelDataset SizeStatus
{fineTune.slug}{dayjs(fineTune.createdAt).format("MMMM D h:mm A")}{fineTune.baseModel}{fineTune.dataset._count.datasetEntries} + {fineTune.status} +
+ ) : ( + + + + No Fine Tunes Found + + + )} +
+ ); +}; + +export default FineTunesTable; + +const getStatusColor = (status: FineTuneStatus) => { + switch (status) { + case "DEPLOYED": + return "green.500"; + case "ERROR": + return "red.500"; + default: + return "yellow.500"; + } +}; diff --git a/app/src/components/nav/AppShell.tsx b/app/src/components/nav/AppShell.tsx index 8600528..7606ad2 100644 --- a/app/src/components/nav/AppShell.tsx +++ b/app/src/components/nav/AppShell.tsx @@ -15,12 +15,14 @@ import Head from "next/head"; import Link from "next/link"; import { BsGearFill, BsGithub, BsPersonCircle } from "react-icons/bs"; import { IoStatsChartOutline } from "react-icons/io5"; -import { RiHome3Line, RiDatabase2Line, RiFlaskLine } from "react-icons/ri"; +import { RiHome3Line, RiFlaskLine } from "react-icons/ri"; +import { FaRobot } from "react-icons/fa"; import { signIn, useSession } from "next-auth/react"; import { env } from "~/env.mjs"; import ProjectMenu from "./ProjectMenu"; import NavSidebarOption from "./NavSidebarOption"; import IconLink from "./IconLink"; +import { BetaModal } from "./BetaModal"; const Divider = () => ; @@ -71,21 +73,10 @@ const NavSidebar = () => { - {env.NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS && ( - <> - - - - )} + + + - {env.NEXT_PUBLIC_SHOW_DATA && ( - - )} - - {title ? `${title} | OpenPipe` : "OpenPipe"} - - - - {children} - - + <> + + + {title ? `${title} | OpenPipe` : "OpenPipe"} + + + + {children} + + + {requireBeta && !env.NEXT_PUBLIC_FF_SHOW_BETA_FEATURES && } + ); } diff --git a/app/src/components/nav/BetaModal.tsx b/app/src/components/nav/BetaModal.tsx new file mode 100644 index 0000000..bf66342 --- /dev/null +++ b/app/src/components/nav/BetaModal.tsx @@ -0,0 +1,67 @@ +import { + Button, + Modal, + ModalBody, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + VStack, + Text, + HStack, + Icon, + Link, +} from "@chakra-ui/react"; +import { BsStars } from "react-icons/bs"; +import { useRouter } from "next/router"; +import { useSession } from "next-auth/react"; + +export const BetaModal = () => { + const router = useRouter(); + const session = useSession(); + + const email = session.data?.user.email ?? ""; + + return ( + + + + + + + Beta-Only Feature + + + + + + This feature is currently in beta. To receive early access to beta-only features, join + the waitlist. You'll receive an email at {email} when you're approved. + + + + + + + + + + + + ); +}; diff --git a/app/src/components/requestLogs/FineTuneButton.tsx b/app/src/components/requestLogs/FineTuneButton.tsx new file mode 100644 index 0000000..2143140 --- /dev/null +++ b/app/src/components/requestLogs/FineTuneButton.tsx @@ -0,0 +1,161 @@ +import { useState, useEffect } from "react"; +import { + Modal, + ModalOverlay, + ModalContent, + ModalHeader, + ModalCloseButton, + ModalBody, + ModalFooter, + HStack, + VStack, + Icon, + Text, + Button, + useDisclosure, + type UseDisclosureReturn, + Input, +} from "@chakra-ui/react"; +import { FaRobot } from "react-icons/fa"; +import humanId from "human-id"; +import { useRouter } from "next/router"; + +import { useHandledAsyncCallback } from "~/utils/hooks"; +import { api } from "~/utils/api"; +import { useAppStore } from "~/state/store"; +import ActionButton from "./ActionButton"; +import InputDropdown from "../InputDropdown"; +import { FiChevronDown } from "react-icons/fi"; + +const SUPPORTED_BASE_MODELS = ["llama2-7b", "llama2-13b", "llama2-70b", "gpt-3.5-turbo"]; + +const FineTuneButton = () => { + const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); + + const disclosure = useDisclosure(); + + return ( + <> + + + + ); +}; + +export default FineTuneButton; + +const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { + const selectedProjectId = useAppStore((s) => s.selectedProjectId); + const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); + const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds); + + const [selectedBaseModel, setSelectedBaseModel] = useState(SUPPORTED_BASE_MODELS[0]); + const [modelSlug, setModelSlug] = useState(humanId({ separator: "-", capitalize: false })); + + useEffect(() => { + if (disclosure.isOpen) { + setSelectedBaseModel(SUPPORTED_BASE_MODELS[0]); + setModelSlug(humanId({ separator: "-", capitalize: false })); + } + }, [disclosure.isOpen]); + + const utils = api.useContext(); + const router = useRouter(); + + const createFineTuneMutation = api.fineTunes.create.useMutation(); + + const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => { + if (!selectedProjectId || !modelSlug || !selectedBaseModel || !selectedLogIds.size) return; + await createFineTuneMutation.mutateAsync({ + projectId: selectedProjectId, + slug: modelSlug, + baseModel: selectedBaseModel, + selectedLogIds: Array.from(selectedLogIds), + }); + + await utils.fineTunes.list.invalidate(); + await router.push({ pathname: "/fine-tunes" }); + clearSelectedLogIds(); + disclosure.onClose(); + }, [createFineTuneMutation, selectedProjectId, selectedLogIds, modelSlug, selectedBaseModel]); + + return ( + + + + + + + Fine Tune + + + + + + + We'll train on the {selectedLogIds.size} logs you've selected. + + + + + Model ID: + + setModelSlug(e.target.value)} + w={48} + placeholder="unique-id" + onKeyDown={(e) => { + // If the user types anything other than a-z, A-Z, or 0-9, replace it with - + if (!/[a-zA-Z0-9]/.test(e.key)) { + e.preventDefault(); + setModelSlug((s) => s && `${s}-`); + } + }} + /> + + + + Base model: + + setSelectedBaseModel(option)} + inputGroupProps={{ w: 48 }} + /> + + + + + + + + + + + + + + ); +}; diff --git a/app/src/components/requestLogs/LoggedCallsTable.tsx b/app/src/components/requestLogs/LoggedCallsTable.tsx index 913ee28..5d68400 100644 --- a/app/src/components/requestLogs/LoggedCallsTable.tsx +++ b/app/src/components/requestLogs/LoggedCallsTable.tsx @@ -10,7 +10,7 @@ export default function LoggedCallsTable() { return ( - + {loggedCalls?.calls?.map((loggedCall) => { return ( @@ -25,7 +25,7 @@ export default function LoggedCallsTable() { setExpandedRow(loggedCall.id); } }} - isSimple + showOptions /> ); })} diff --git a/app/src/components/requestLogs/TableRow.tsx b/app/src/components/requestLogs/TableRow.tsx index a024d48..85fac83 100644 --- a/app/src/components/requestLogs/TableRow.tsx +++ b/app/src/components/requestLogs/TableRow.tsx @@ -14,10 +14,9 @@ import { Text, Checkbox, } from "@chakra-ui/react"; -import dayjs from "dayjs"; -import relativeTime from "dayjs/plugin/relativeTime"; import Link from "next/link"; +import dayjs from "~/utils/dayjs"; import { type RouterOutputs } from "~/utils/api"; import { FormattedJson } from "./FormattedJson"; import { useAppStore } from "~/state/store"; @@ -25,11 +24,9 @@ import { useIsClientRehydrated, useLoggedCalls, useTagNames } from "~/utils/hook import { useMemo } from "react"; import { StaticColumnKeys } from "~/state/columnVisiblitySlice"; -dayjs.extend(relativeTime); - type LoggedCall = RouterOutputs["loggedCalls"]["list"]["calls"][0]; -export const TableHeader = ({ isSimple }: { isSimple?: boolean }) => { +export const TableHeader = ({ showOptions }: { showOptions?: boolean }) => { const matchingLogIds = useLoggedCalls().data?.matchingLogIds; const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); const addAll = useAppStore((s) => s.selectedLogs.addSelectedLogIds); @@ -46,7 +43,7 @@ export const TableHeader = ({ isSimple }: { isSimple?: boolean }) => { return ( - {isSimple && ( + {showOptions && ( @@ -147,9 +148,9 @@ export const TableRow = ({ )} - {tagNames - ?.filter((tagName) => visibleColumns.has(tagName)) - .map((tagName) => )} + {visibleTagNames.map((tagName) => ( + + ))} {visibleColumns.has(StaticColumnKeys.DURATION) && ( -
void; - isSimple?: boolean; + showOptions?: boolean; }) => { const isError = loggedCall.modelResponse?.statusCode !== 200; const requestedAt = dayjs(loggedCall.requestedAt).format("MMMM D h:mm A"); @@ -101,6 +98,10 @@ export const TableRow = ({ const tagNames = useTagNames().data; const visibleColumns = useAppStore((s) => s.columnVisibility.visibleColumns); + const visibleTagNames = useMemo(() => { + return tagNames?.filter((tagName) => visibleColumns.has(tagName)) ?? []; + }, [tagNames, visibleColumns]); + const isClientRehydrated = useIsClientRehydrated(); if (!isClientRehydrated) return null; @@ -115,7 +116,7 @@ export const TableRow = ({ }} fontSize="sm" > - {isSimple && ( + {showOptions && ( toggleChecked(loggedCall.id)} /> {loggedCall.tags[tagName]}{loggedCall.tags[tagName]} {loggedCall.cacheHit ? ( @@ -172,7 +173,7 @@ export const TableRow = ({ )}
+ diff --git a/app/src/env.mjs b/app/src/env.mjs index 2f6fada..433b95b 100644 --- a/app/src/env.mjs +++ b/app/src/env.mjs @@ -46,8 +46,7 @@ export const env = createEnv({ NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"), NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"), NEXT_PUBLIC_SENTRY_DSN: z.string().optional(), - NEXT_PUBLIC_SHOW_DATA: z.string().optional(), - NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS: z.string().optional(), + NEXT_PUBLIC_FF_SHOW_BETA_FEATURES: z.string().optional(), }, /** @@ -62,7 +61,6 @@ export const env = createEnv({ NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY, NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL, NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST, - NEXT_PUBLIC_SHOW_DATA: process.env.NEXT_PUBLIC_SHOW_DATA, GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET, REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN, @@ -70,7 +68,7 @@ export const env = createEnv({ NEXT_PUBLIC_SENTRY_DSN: process.env.NEXT_PUBLIC_SENTRY_DSN, SENTRY_AUTH_TOKEN: process.env.SENTRY_AUTH_TOKEN, OPENPIPE_API_KEY: process.env.OPENPIPE_API_KEY, - NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS: process.env.NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS, + NEXT_PUBLIC_FF_SHOW_BETA_FEATURES: process.env.NEXT_PUBLIC_FF_SHOW_BETA_FEATURES, SENDER_EMAIL: process.env.SENDER_EMAIL, SMTP_HOST: process.env.SMTP_HOST, SMTP_PORT: process.env.SMTP_PORT, diff --git a/app/src/pages/dashboard/index.tsx b/app/src/pages/dashboard/index.tsx index aa790b4..7f58582 100644 --- a/app/src/pages/dashboard/index.tsx +++ b/app/src/pages/dashboard/index.tsx @@ -33,7 +33,7 @@ export default function Dashboard() { ); return ( - + Dashboard diff --git a/app/src/pages/data/[id].tsx b/app/src/pages/data/[id].tsx deleted file mode 100644 index 6401630..0000000 --- a/app/src/pages/data/[id].tsx +++ /dev/null @@ -1,97 +0,0 @@ -import { - Box, - Breadcrumb, - BreadcrumbItem, - Center, - Flex, - Icon, - Input, - VStack, -} from "@chakra-ui/react"; -import Link from "next/link"; - -import { useRouter } from "next/router"; -import { useState, useEffect } from "react"; -import { RiDatabase2Line } from "react-icons/ri"; -import AppShell from "~/components/nav/AppShell"; -import { api } from "~/utils/api"; -import { useDataset, useHandledAsyncCallback } from "~/utils/hooks"; -import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable"; -import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons"; -import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; -import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; - -export default function Dataset() { - const router = useRouter(); - const utils = api.useContext(); - - const dataset = useDataset(); - const datasetId = router.query.id as string; - - const [name, setName] = useState(dataset.data?.name || ""); - useEffect(() => { - setName(dataset.data?.name || ""); - }, [dataset.data?.name]); - - const updateMutation = api.datasets.update.useMutation(); - const [onSaveName] = useHandledAsyncCallback(async () => { - if (name && name !== dataset.data?.name && dataset.data?.id) { - await updateMutation.mutateAsync({ - id: dataset.data.id, - updates: { name: name }, - }); - await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]); - } - }, [updateMutation, dataset.data?.id, dataset.data?.name, name]); - - if (!dataset.isLoading && !dataset.data) { - return ( - -
-
Dataset not found 😕
-
-
- ); - } - - return ( - - - - - - - - - - - Datasets - - - - - setName(e.target.value)} - onBlur={onSaveName} - borderWidth={1} - borderColor="transparent" - fontSize={16} - px={0} - minW={{ base: 100, lg: 300 }} - flex={1} - _hover={{ borderColor: "gray.300" }} - _focus={{ borderColor: "blue.500", outline: "none" }} - /> - - - - - - {datasetId && } - - - - ); -} diff --git a/app/src/pages/data/index.tsx b/app/src/pages/data/index.tsx deleted file mode 100644 index fb55389..0000000 --- a/app/src/pages/data/index.tsx +++ /dev/null @@ -1,49 +0,0 @@ -import { SimpleGrid, Icon, Breadcrumb, BreadcrumbItem, Flex } from "@chakra-ui/react"; -import AppShell from "~/components/nav/AppShell"; -import { RiDatabase2Line } from "react-icons/ri"; -import { - DatasetCard, - DatasetCardSkeleton, - NewDatasetCard, -} from "~/components/datasets/DatasetCard"; -import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; -import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; -import { useDatasets } from "~/utils/hooks"; - -export default function DatasetsPage() { - const datasets = useDatasets(); - - return ( - - - - - - - - - Datasets - - - - - - - {datasets.data && !datasets.isLoading ? ( - datasets?.data?.map((dataset) => ( - - )) - ) : ( - <> - - - - - )} - - - ); -} diff --git a/app/src/pages/fine-tunes/index.tsx b/app/src/pages/fine-tunes/index.tsx new file mode 100644 index 0000000..6d8b6ec --- /dev/null +++ b/app/src/pages/fine-tunes/index.tsx @@ -0,0 +1,18 @@ +import { Text, VStack, Divider } from "@chakra-ui/react"; +import FineTunesTable from "~/components/fineTunes/FineTunesTable"; + +import AppShell from "~/components/nav/AppShell"; + +export default function FineTunes() { + return ( + + + + Fine Tunes + + + + + + ); +} diff --git a/app/src/pages/request-logs/index.tsx b/app/src/pages/request-logs/index.tsx index d2e2916..1cd5a4b 100644 --- a/app/src/pages/request-logs/index.tsx +++ b/app/src/pages/request-logs/index.tsx @@ -1,5 +1,5 @@ import { useState } from "react"; -import { Text, VStack, Divider, HStack } from "@chakra-ui/react"; +import { Text, VStack, Divider, HStack, Box } from "@chakra-ui/react"; import AppShell from "~/components/nav/AppShell"; import LoggedCallTable from "~/components/requestLogs/LoggedCallsTable"; @@ -10,6 +10,7 @@ import { RiFlaskLine } from "react-icons/ri"; import { FiFilter } from "react-icons/fi"; import LogFilters from "~/components/requestLogs/LogFilters/LogFilters"; import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown"; +import FineTuneButton from "~/components/requestLogs/FineTuneButton"; export default function LoggedCalls() { const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); @@ -17,34 +18,37 @@ export default function LoggedCalls() { const [filtersShown, setFiltersShown] = useState(true); return ( - - - - Request Logs - - - - - { - setFiltersShown(!filtersShown); - }} - label={filtersShown ? "Hide Filters" : "Show Filters"} - icon={FiFilter} - /> - { - console.log("experimenting with these ids", selectedLogIds); - }} - label="Experiment" - icon={RiFlaskLine} - isDisabled={selectedLogIds.size === 0} - /> - - {filtersShown && } - - - + + + + + Request Logs + + + + + { + console.log("experimenting with these ids", selectedLogIds); + }} + label="Experiment" + icon={RiFlaskLine} + isDisabled={selectedLogIds.size === 0} + /> + + { + setFiltersShown(!filtersShown); + }} + label={filtersShown ? "Hide Filters" : "Show Filters"} + icon={FiFilter} + /> + + {filtersShown && } + + + + ); } diff --git a/app/src/server/api/autogenerate/autogenerateDatasetEntries.ts b/app/src/server/api/autogenerate/autogenerateDatasetEntries.ts deleted file mode 100644 index 1fb18a3..0000000 --- a/app/src/server/api/autogenerate/autogenerateDatasetEntries.ts +++ /dev/null @@ -1,113 +0,0 @@ -import { type ChatCompletion } from "openai/resources/chat"; -import { openai } from "../../utils/openai"; -import { isAxiosError } from "./utils"; -import { type APIResponse } from "openai/core"; -import { sleep } from "~/server/utils/sleep"; - -const MAX_AUTO_RETRIES = 50; -const MIN_DELAY = 500; // milliseconds -const MAX_DELAY = 15000; // milliseconds - -function calculateDelay(numPreviousTries: number): number { - const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries)); - const jitter = Math.random() * baseDelay; - return baseDelay + jitter; -} - -const getCompletionWithBackoff = async ( - getCompletion: () => Promise>, -) => { - let completion; - let tries = 0; - while (tries < MAX_AUTO_RETRIES) { - try { - completion = await getCompletion(); - break; - } catch (e) { - if (isAxiosError(e)) { - console.error(e?.response?.data?.error?.message); - } else { - await sleep(calculateDelay(tries)); - console.error(e); - } - } - tries++; - } - return completion; -}; -// TODO: Add seeds to ensure batches don't contain duplicate data -const MAX_BATCH_SIZE = 5; - -export const autogenerateDatasetEntries = async ( - numToGenerate: number, - inputDescription: string, - outputDescription: string, -): Promise<{ input: string; output: string }[]> => { - const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) => - i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE - ? numToGenerate % MAX_BATCH_SIZE - : MAX_BATCH_SIZE, - ); - - const getCompletion = (batchSize: number) => - openai.chat.completions.create({ - model: "gpt-4", - messages: [ - { - role: "system", - content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`, - }, - ], - functions: [ - { - name: "add_list_of_data", - description: "Add a list of data to the database", - parameters: { - type: "object", - properties: { - rows: { - type: "array", - description: "The rows of data that match the description", - items: { - type: "object", - properties: { - input: { - type: "string", - description: "The input for this row", - }, - output: { - type: "string", - description: "The output for this row", - }, - }, - }, - }, - }, - }, - }, - ], - - function_call: { name: "add_list_of_data" }, - temperature: 0.5, - openpipe: { - tags: { - prompt_id: "autogenerateDatasetEntries", - }, - }, - }); - - const completionCallbacks = batchSizes.map((batchSize) => - getCompletionWithBackoff(() => getCompletion(batchSize)), - ); - - const completions = await Promise.all(completionCallbacks); - - const rows = completions.flatMap((completion) => { - const parsed = JSON.parse( - completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}", - ) as { rows: { input: string; output: string }[] }; - return parsed.rows; - }); - - return rows; -}; diff --git a/app/src/server/api/root.router.ts b/app/src/server/api/root.router.ts index 0d0f51e..dd88fa8 100644 --- a/app/src/server/api/root.router.ts +++ b/app/src/server/api/root.router.ts @@ -6,11 +6,10 @@ import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.route import { scenarioVarsRouter } from "./routers/scenarioVariables.router"; import { evaluationsRouter } from "./routers/evaluations.router"; import { worldChampsRouter } from "./routers/worldChamps.router"; -import { datasetsRouter } from "./routers/datasets.router"; -import { datasetEntries } from "./routers/datasetEntries.router"; import { projectsRouter } from "./routers/projects.router"; import { dashboardRouter } from "./routers/dashboard.router"; import { loggedCallsRouter } from "./routers/loggedCalls.router"; +import { fineTunesRouter } from "./routers/fineTunes.router"; import { usersRouter } from "./routers/users.router"; import { adminJobsRouter } from "./routers/adminJobs.router"; @@ -27,11 +26,10 @@ export const appRouter = createTRPCRouter({ scenarioVars: scenarioVarsRouter, evaluations: evaluationsRouter, worldChamps: worldChampsRouter, - datasets: datasetsRouter, - datasetEntries: datasetEntries, projects: projectsRouter, dashboard: dashboardRouter, loggedCalls: loggedCallsRouter, + fineTunes: fineTunesRouter, users: usersRouter, adminJobs: adminJobsRouter, }); diff --git a/app/src/server/api/routers/datasetEntries.router.ts b/app/src/server/api/routers/datasetEntries.router.ts deleted file mode 100644 index a79782b..0000000 --- a/app/src/server/api/routers/datasetEntries.router.ts +++ /dev/null @@ -1,145 +0,0 @@ -import { z } from "zod"; -import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; -import { prisma } from "~/server/db"; -import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl"; -import { autogenerateDatasetEntries } from "../autogenerate/autogenerateDatasetEntries"; - -export const datasetEntries = createTRPCRouter({ - list: protectedProcedure - .input(z.object({ datasetId: z.string(), page: z.number(), pageSize: z.number() })) - .query(async ({ input, ctx }) => { - await requireCanViewDataset(input.datasetId, ctx); - - const { datasetId, page, pageSize } = input; - - const entries = await prisma.datasetEntry.findMany({ - where: { - datasetId, - }, - orderBy: { createdAt: "desc" }, - skip: (page - 1) * pageSize, - take: pageSize, - }); - - const count = await prisma.datasetEntry.count({ - where: { - datasetId, - }, - }); - - return { - entries, - count, - }; - }), - createOne: protectedProcedure - .input( - z.object({ - datasetId: z.string(), - input: z.string(), - output: z.string().optional(), - }), - ) - .mutation(async ({ input, ctx }) => { - await requireCanModifyDataset(input.datasetId, ctx); - - return await prisma.datasetEntry.create({ - data: { - datasetId: input.datasetId, - input: input.input, - output: input.output, - }, - }); - }), - - autogenerateEntries: protectedProcedure - .input( - z.object({ - datasetId: z.string(), - numToGenerate: z.number(), - inputDescription: z.string(), - outputDescription: z.string(), - }), - ) - .mutation(async ({ input, ctx }) => { - await requireCanModifyDataset(input.datasetId, ctx); - - const dataset = await prisma.dataset.findUnique({ - where: { - id: input.datasetId, - }, - }); - - if (!dataset) { - throw new Error(`Dataset with id ${input.datasetId} does not exist`); - } - - const entries = await autogenerateDatasetEntries( - input.numToGenerate, - input.inputDescription, - input.outputDescription, - ); - - const createdEntries = await prisma.datasetEntry.createMany({ - data: entries.map((entry) => ({ - datasetId: input.datasetId, - input: entry.input, - output: entry.output, - })), - }); - - return createdEntries; - }), - - delete: protectedProcedure - .input(z.object({ id: z.string() })) - .mutation(async ({ input, ctx }) => { - const datasetId = ( - await prisma.datasetEntry.findUniqueOrThrow({ - where: { id: input.id }, - }) - ).datasetId; - - await requireCanModifyDataset(datasetId, ctx); - - return await prisma.datasetEntry.delete({ - where: { - id: input.id, - }, - }); - }), - - update: protectedProcedure - .input( - z.object({ - id: z.string(), - updates: z.object({ - input: z.string(), - output: z.string().optional(), - }), - }), - ) - .mutation(async ({ input, ctx }) => { - const existing = await prisma.datasetEntry.findUnique({ - where: { - id: input.id, - }, - }); - - if (!existing) { - throw new Error(`dataEntry with id ${input.id} does not exist`); - } - - await requireCanModifyDataset(existing.datasetId, ctx); - - return await prisma.datasetEntry.update({ - where: { - id: input.id, - }, - data: { - input: input.updates.input, - output: input.updates.output, - }, - }); - }), -}); diff --git a/app/src/server/api/routers/datasets.router.ts b/app/src/server/api/routers/datasets.router.ts deleted file mode 100644 index 92579a2..0000000 --- a/app/src/server/api/routers/datasets.router.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { z } from "zod"; -import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; -import { prisma } from "~/server/db"; -import { - requireCanModifyDataset, - requireCanModifyProject, - requireCanViewDataset, - requireCanViewProject, -} from "~/utils/accessControl"; - -export const datasetsRouter = createTRPCRouter({ - list: protectedProcedure - .input(z.object({ projectId: z.string() })) - .query(async ({ input, ctx }) => { - await requireCanViewProject(input.projectId, ctx); - - const datasets = await prisma.dataset.findMany({ - where: { - projectId: input.projectId, - }, - orderBy: { - createdAt: "desc", - }, - include: { - _count: { - select: { datasetEntries: true }, - }, - }, - }); - - return datasets; - }), - - get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { - await requireCanViewDataset(input.id, ctx); - return await prisma.dataset.findFirstOrThrow({ - where: { id: input.id }, - include: { - project: true, - }, - }); - }), - - create: protectedProcedure - .input(z.object({ projectId: z.string() })) - .mutation(async ({ input, ctx }) => { - await requireCanModifyProject(input.projectId, ctx); - - const numDatasets = await prisma.dataset.count({ - where: { - projectId: input.projectId, - }, - }); - - return await prisma.dataset.create({ - data: { - name: `Dataset ${numDatasets + 1}`, - projectId: input.projectId, - }, - }); - }), - - update: protectedProcedure - .input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) })) - .mutation(async ({ input, ctx }) => { - await requireCanModifyDataset(input.id, ctx); - return await prisma.dataset.update({ - where: { - id: input.id, - }, - data: { - name: input.updates.name, - }, - }); - }), - - delete: protectedProcedure - .input(z.object({ id: z.string() })) - .mutation(async ({ input, ctx }) => { - await requireCanModifyDataset(input.id, ctx); - - await prisma.dataset.delete({ - where: { - id: input.id, - }, - }); - }), -}); diff --git a/app/src/server/api/routers/fineTunes.router.ts b/app/src/server/api/routers/fineTunes.router.ts new file mode 100644 index 0000000..cff976e --- /dev/null +++ b/app/src/server/api/routers/fineTunes.router.ts @@ -0,0 +1,113 @@ +import { z } from "zod"; +import { v4 as uuidv4 } from "uuid"; +import { type Prisma } from "@prisma/client"; + +import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import { requireCanViewProject, requireCanModifyProject } from "~/utils/accessControl"; +import { error, success } from "~/utils/errorHandling/standardResponses"; + +export const fineTunesRouter = createTRPCRouter({ + list: protectedProcedure + .input( + z.object({ + projectId: z.string(), + page: z.number(), + pageSize: z.number(), + }), + ) + .query(async ({ input, ctx }) => { + const { projectId, page, pageSize } = input; + + await requireCanViewProject(projectId, ctx); + + const fineTunes = await prisma.fineTune.findMany({ + where: { + projectId, + }, + include: { + dataset: { + include: { + _count: { + select: { + datasetEntries: true, + }, + }, + }, + }, + }, + orderBy: { createdAt: "asc" }, + skip: (page - 1) * pageSize, + take: pageSize, + }); + + const count = await prisma.fineTune.count({ + where: { + projectId, + }, + }); + + return { + fineTunes, + count, + }; + }), + create: protectedProcedure + .input( + z.object({ + projectId: z.string(), + selectedLogIds: z.array(z.string()), + slug: z.string(), + baseModel: z.string(), + }), + ) + .mutation(async ({ input, ctx }) => { + await requireCanModifyProject(input.projectId, ctx); + + const existingFineTune = await prisma.fineTune.findFirst({ + where: { + slug: input.slug, + }, + }); + + if (existingFineTune) { + return error("A fine tune with that slug already exists"); + } + + const newDatasetId = uuidv4(); + + const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyDatasetInput[] = + input.selectedLogIds.map((loggedCallId) => ({ + loggedCallId, + })); + + await prisma.$transaction([ + prisma.dataset.create({ + data: { + id: newDatasetId, + name: input.slug, + project: { + connect: { + id: input.projectId, + }, + }, + datasetEntries: { + createMany: { + data: datasetEntriesToCreate, + }, + }, + }, + }), + prisma.fineTune.create({ + data: { + projectId: input.projectId, + slug: input.slug, + baseModel: input.baseModel, + datasetId: newDatasetId, + }, + }), + ]); + + return success(); + }), +}); diff --git a/app/src/utils/accessControl.ts b/app/src/utils/accessControl.ts index a09764d..5ca314c 100644 --- a/app/src/utils/accessControl.ts +++ b/app/src/utils/accessControl.ts @@ -78,33 +78,6 @@ export const requireCanModifyProject = async (projectId: string, ctx: TRPCContex } }; -export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => { - ctx.markAccessControlRun(); - - const dataset = await prisma.dataset.findFirst({ - where: { - id: datasetId, - project: { - projectUsers: { - some: { - role: { in: [ProjectUserRole.ADMIN, ProjectUserRole.MEMBER] }, - userId: ctx.session?.user.id, - }, - }, - }, - }, - }); - - if (!dataset) { - throw new TRPCError({ code: "UNAUTHORIZED" }); - } -}; - -export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => { - // Right now all users who can view a dataset can also modify it - await requireCanViewDataset(datasetId, ctx); -}; - export const requireCanViewExperiment = (experimentId: string, ctx: TRPCContext): Promise => { // Right now all experiments are publicly viewable, so this is a no-op. ctx.markAccessControlRun(); diff --git a/app/src/utils/hooks.ts b/app/src/utils/hooks.ts index 5e508f7..4628a00 100644 --- a/app/src/utils/hooks.ts +++ b/app/src/utils/hooks.ts @@ -26,34 +26,6 @@ export const useExperimentAccess = () => { return useExperiment().data?.access ?? { canView: false, canModify: false }; }; -export const useDatasets = () => { - const selectedProjectId = useAppStore((state) => state.selectedProjectId); - return api.datasets.list.useQuery( - { projectId: selectedProjectId ?? "" }, - { enabled: !!selectedProjectId }, - ); -}; - -export const useDataset = () => { - const router = useRouter(); - const dataset = api.datasets.get.useQuery( - { id: router.query.id as string }, - { enabled: !!router.query.id }, - ); - - return dataset; -}; - -export const useDatasetEntries = () => { - const dataset = useDataset(); - const { page, pageSize } = usePageParams(); - - return api.datasetEntries.list.useQuery( - { datasetId: dataset.data?.id ?? "", page, pageSize }, - { enabled: dataset.data?.id != null }, - ); -}; - type AsyncFunction = (...args: T) => Promise; export function useHandledAsyncCallback( @@ -206,6 +178,16 @@ export const useTagNames = () => { ); }; +export const useFineTunes = () => { + const selectedProjectId = useAppStore((state) => state.selectedProjectId); + const { page, pageSize } = usePageParams(); + + return api.fineTunes.list.useQuery( + { projectId: selectedProjectId ?? "", page, pageSize }, + { enabled: !!selectedProjectId }, + ); +}; + export const useIsClientRehydrated = () => { const isRehydrated = useAppStore((state) => state.isRehydrated); const [isMounted, setIsMounted] = useState(false); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index fd5fce1..73723e0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -116,6 +116,9 @@ importers: graphile-worker: specifier: ^0.13.0 version: 0.13.0 + human-id: + specifier: ^4.0.0 + version: 4.0.0 immer: specifier: ^10.0.2 version: 10.0.2 @@ -5942,6 +5945,10 @@ packages: - supports-color dev: false + /human-id@4.0.0: + resolution: {integrity: sha512-pui0xZRgeAlaRt0I9r8N2pNlbNmluvn71EfjKRpM7jOpZbuHe5mm76r67gcprjw/Nd+GpvB9C3OlTbh7ZKLg7A==} + dev: false + /humanize-ms@1.2.1: resolution: {integrity: sha512-Fl70vYtsAFb/C06PTS9dZBo7ihau+Tu/DNCk/OyHhea07S+aeMWpFFkUaXRa8fI+ScZbEI8dfSxwY7gxZ9SAVQ==} dependencies: