Add flow for fine-tuning (#183)
* Remove unnecessary dataset code * Fix jump on row selection * Add FineTuneButton * Add model slug to modal * Add fine tunes to schema * Remove dataset routers * Remove more dataset-specific code * Remove more data code * Fix horizontal scroll bar jumping * Add fine tunes page * Actually create the fine tune entry * Add beta modal * Require beta for fine tunes and request logs * Send user to waitlist link * control beta features in .env variable * Combine migration files * Show beta features in app shell * Clear selected log ids last when closing fine tune modal * Remove ModalCloseButton from BetaModal * Remove unused import * Change timestamps to camelCase
This commit is contained in:
3
app/@types/nextjs-routes.d.ts
vendored
3
app/@types/nextjs-routes.d.ts
vendored
@@ -19,10 +19,9 @@ declare module "nextjs-routes" {
|
||||
| DynamicRoute<"/api/v1/[...trpc]", { "trpc": string[] }>
|
||||
| StaticRoute<"/api/v1/openapi">
|
||||
| StaticRoute<"/dashboard">
|
||||
| DynamicRoute<"/data/[id]", { "id": string }>
|
||||
| StaticRoute<"/data">
|
||||
| DynamicRoute<"/experiments/[experimentSlug]", { "experimentSlug": string }>
|
||||
| StaticRoute<"/experiments">
|
||||
| StaticRoute<"/fine-tunes">
|
||||
| StaticRoute<"/">
|
||||
| DynamicRoute<"/invitations/[invitationToken]", { "invitationToken": string }>
|
||||
| StaticRoute<"/project/settings">
|
||||
|
||||
@@ -23,7 +23,7 @@ ARG NEXT_PUBLIC_SOCKET_URL
|
||||
ARG NEXT_PUBLIC_HOST
|
||||
ARG NEXT_PUBLIC_SENTRY_DSN
|
||||
ARG SENTRY_AUTH_TOKEN
|
||||
ARG NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS
|
||||
ARG NEXT_PUBLIC_FF_SHOW_BETA_FEATURES
|
||||
|
||||
WORKDIR /code
|
||||
COPY --from=deps /code/node_modules ./node_modules
|
||||
|
||||
@@ -60,6 +60,7 @@
|
||||
"framer-motion": "^10.12.17",
|
||||
"gpt-tokens": "^1.0.10",
|
||||
"graphile-worker": "^0.13.0",
|
||||
"human-id": "^4.0.0",
|
||||
"immer": "^10.0.2",
|
||||
"isolated-vm": "^4.5.0",
|
||||
"json-schema-to-typescript": "^13.0.2",
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `input` on the `DatasetEntry` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `output` on the `DatasetEntry` table. All the data in the column will be lost.
|
||||
- Added the required column `loggedCallId` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "DatasetEntry" DROP COLUMN "input",
|
||||
DROP COLUMN "output",
|
||||
ADD COLUMN "loggedCallId" UUID NOT NULL;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "DatasetEntry" ADD CONSTRAINT "DatasetEntry_loggedCallId_fkey" FOREIGN KEY ("loggedCallId") REFERENCES "LoggedCall"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "LoggedCallModelResponse" ALTER COLUMN "cost" SET DATA TYPE DOUBLE PRECISION;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "FineTuneStatus" AS ENUM ('PENDING', 'TRAINING', 'AWAITING_DEPLOYMENT', 'DEPLOYING', 'DEPLOYED', 'ERROR');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "FineTune" (
|
||||
"id" UUID NOT NULL,
|
||||
"slug" TEXT NOT NULL,
|
||||
"baseModel" TEXT NOT NULL,
|
||||
"status" "FineTuneStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"trainingStartedAt" TIMESTAMP(3),
|
||||
"trainingFinishedAt" TIMESTAMP(3),
|
||||
"deploymentStartedAt" TIMESTAMP(3),
|
||||
"deploymentFinishedAt" TIMESTAMP(3),
|
||||
"datasetId" UUID NOT NULL,
|
||||
"projectId" UUID NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "FineTune_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "FineTune_slug_key" ON "FineTune"("slug");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "FineTune" ADD CONSTRAINT "FineTune_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "FineTune" ADD CONSTRAINT "FineTune_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -181,6 +181,7 @@ model Dataset {
|
||||
|
||||
name String
|
||||
datasetEntries DatasetEntry[]
|
||||
fineTunes FineTune[]
|
||||
|
||||
projectId String @db.Uuid
|
||||
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
|
||||
@@ -192,8 +193,8 @@ model Dataset {
|
||||
model DatasetEntry {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
input String
|
||||
output String?
|
||||
loggedCallId String @db.Uuid
|
||||
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
|
||||
|
||||
datasetId String @db.Uuid
|
||||
dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade)
|
||||
@@ -216,6 +217,7 @@ model Project {
|
||||
experiments Experiment[]
|
||||
datasets Dataset[]
|
||||
loggedCalls LoggedCall[]
|
||||
fineTunes FineTune[]
|
||||
apiKeys ApiKey[]
|
||||
}
|
||||
|
||||
@@ -276,8 +278,9 @@ model LoggedCall {
|
||||
projectId String @db.Uuid
|
||||
project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade)
|
||||
|
||||
model String?
|
||||
tags LoggedCallTag[]
|
||||
model String?
|
||||
tags LoggedCallTag[]
|
||||
datasetEntries DatasetEntry[]
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
@@ -312,7 +315,7 @@ model LoggedCallModelResponse {
|
||||
outputTokens Int?
|
||||
finishReason String?
|
||||
completionId String?
|
||||
cost Decimal? @db.Decimal(18, 12)
|
||||
cost Float?
|
||||
|
||||
// The LoggedCall that created this LoggedCallModelResponse
|
||||
originalLoggedCallId String @unique @db.Uuid
|
||||
@@ -427,3 +430,33 @@ model VerificationToken {
|
||||
|
||||
@@unique([identifier, token])
|
||||
}
|
||||
|
||||
enum FineTuneStatus {
|
||||
PENDING
|
||||
TRAINING
|
||||
AWAITING_DEPLOYMENT
|
||||
DEPLOYING
|
||||
DEPLOYED
|
||||
ERROR
|
||||
}
|
||||
|
||||
model FineTune {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
slug String @unique
|
||||
baseModel String
|
||||
status FineTuneStatus @default(PENDING)
|
||||
trainingStartedAt DateTime?
|
||||
trainingFinishedAt DateTime?
|
||||
deploymentStartedAt DateTime?
|
||||
deploymentFinishedAt DateTime?
|
||||
|
||||
datasetId String @db.Uuid
|
||||
dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade)
|
||||
|
||||
projectId String @db.Uuid
|
||||
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
Button,
|
||||
Text,
|
||||
useDisclosure,
|
||||
type InputGroupProps,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
import { FiChevronDown } from "react-icons/fi";
|
||||
@@ -20,15 +21,25 @@ type InputDropdownProps<T> = {
|
||||
options: ReadonlyArray<T>;
|
||||
selectedOption: T;
|
||||
onSelect: (option: T) => void;
|
||||
inputGroupProps?: InputGroupProps;
|
||||
};
|
||||
|
||||
const InputDropdown = <T,>({ options, selectedOption, onSelect }: InputDropdownProps<T>) => {
|
||||
const InputDropdown = <T,>({
|
||||
options,
|
||||
selectedOption,
|
||||
onSelect,
|
||||
inputGroupProps,
|
||||
}: InputDropdownProps<T>) => {
|
||||
const popover = useDisclosure();
|
||||
|
||||
return (
|
||||
<Popover placement="bottom-start" {...popover}>
|
||||
<PopoverTrigger>
|
||||
<InputGroup cursor="pointer" w={(selectedOption as string).length * 14 + 180}>
|
||||
<InputGroup
|
||||
cursor="pointer"
|
||||
w={(selectedOption as string).length * 14 + 180}
|
||||
{...inputGroupProps}
|
||||
>
|
||||
<Input
|
||||
value={selectedOption as string}
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function -- controlled input requires onChange
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import {
|
||||
HStack,
|
||||
Icon,
|
||||
VStack,
|
||||
Text,
|
||||
Divider,
|
||||
Spinner,
|
||||
AspectRatio,
|
||||
SkeletonText,
|
||||
} from "@chakra-ui/react";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import { formatTimePast } from "~/utils/dayjs";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { BsPlusSquare } from "react-icons/bs";
|
||||
import { api } from "~/utils/api";
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { useAppStore } from "~/state/store";
|
||||
|
||||
type DatasetData = {
|
||||
name: string;
|
||||
numEntries: number;
|
||||
id: string;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
};
|
||||
|
||||
export const DatasetCard = ({ dataset }: { dataset: DatasetData }) => {
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
as={Link}
|
||||
href={{ pathname: "/data/[id]", query: { id: dataset.id } }}
|
||||
bg="gray.50"
|
||||
_hover={{ bg: "gray.100" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
justify="space-between"
|
||||
>
|
||||
<HStack w="full" color="gray.700" justify="center">
|
||||
<Icon as={RiDatabase2Line} boxSize={4} />
|
||||
<Text fontWeight="bold">{dataset.name}</Text>
|
||||
</HStack>
|
||||
<HStack h="full" spacing={4} flex={1} align="center">
|
||||
<CountLabel label="Rows" count={dataset.numEntries} />
|
||||
</HStack>
|
||||
<HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
|
||||
<Text flex={1}>Created {formatTimePast(dataset.createdAt)}</Text>
|
||||
<Divider h={4} orientation="vertical" />
|
||||
<Text flex={1}>Updated {formatTimePast(dataset.updatedAt)}</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
const CountLabel = ({ label, count }: { label: string; count: number }) => {
|
||||
return (
|
||||
<VStack alignItems="center" flex={1}>
|
||||
<Text color="gray.500" fontWeight="bold">
|
||||
{label}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="gray.500">
|
||||
{count}
|
||||
</Text>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export const NewDatasetCard = () => {
|
||||
const router = useRouter();
|
||||
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
|
||||
const createMutation = api.datasets.create.useMutation();
|
||||
const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
|
||||
const newDataset = await createMutation.mutateAsync({ projectId: selectedProjectId ?? "" });
|
||||
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
|
||||
}, [createMutation, router, selectedProjectId]);
|
||||
|
||||
return (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack
|
||||
align="center"
|
||||
justify="center"
|
||||
_hover={{ cursor: "pointer", bg: "gray.50" }}
|
||||
transition="background 0.2s"
|
||||
cursor="pointer"
|
||||
borderColor="gray.200"
|
||||
borderWidth={1}
|
||||
p={4}
|
||||
onClick={createDataset}
|
||||
>
|
||||
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
|
||||
<Text display={{ base: "none", md: "block" }} ml={2}>
|
||||
New Dataset
|
||||
</Text>
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
};
|
||||
|
||||
export const DatasetCardSkeleton = () => (
|
||||
<AspectRatio ratio={1.2} w="full">
|
||||
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
<SkeletonText noOfLines={2} w="60%" />
|
||||
<SkeletonText noOfLines={1} w="80%" />
|
||||
</VStack>
|
||||
</AspectRatio>
|
||||
);
|
||||
@@ -1,16 +0,0 @@
|
||||
import { type StackProps } from "@chakra-ui/react";
|
||||
|
||||
import { useDatasetEntries } from "~/utils/hooks";
|
||||
import Paginator from "../Paginator";
|
||||
|
||||
const DatasetEntriesPaginator = (props: StackProps) => {
|
||||
const { data } = useDatasetEntries();
|
||||
|
||||
if (!data) return null;
|
||||
|
||||
const { count } = data;
|
||||
|
||||
return <Paginator count={count} {...props} />;
|
||||
};
|
||||
|
||||
export default DatasetEntriesPaginator;
|
||||
@@ -1,31 +0,0 @@
|
||||
import { type StackProps, VStack, Table, Th, Tr, Thead, Tbody, Text } from "@chakra-ui/react";
|
||||
import { useDatasetEntries } from "~/utils/hooks";
|
||||
import TableRow from "./TableRow";
|
||||
import DatasetEntriesPaginator from "./DatasetEntriesPaginator";
|
||||
|
||||
const DatasetEntriesTable = (props: StackProps) => {
|
||||
const { data } = useDatasetEntries();
|
||||
|
||||
return (
|
||||
<VStack justifyContent="space-between" {...props}>
|
||||
<Table variant="simple" sx={{ "table-layout": "fixed", width: "full" }}>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Input</Th>
|
||||
<Th>Output</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>{data?.entries.map((entry) => <TableRow key={entry.id} entry={entry} />)}</Tbody>
|
||||
</Table>
|
||||
{(!data || data.entries.length) === 0 ? (
|
||||
<Text alignSelf="flex-start" pl={6} color="gray.500">
|
||||
No entries found
|
||||
</Text>
|
||||
) : (
|
||||
<DatasetEntriesPaginator />
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default DatasetEntriesTable;
|
||||
@@ -1,26 +0,0 @@
|
||||
import { Button, HStack, useDisclosure } from "@chakra-ui/react";
|
||||
import { BiImport } from "react-icons/bi";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
|
||||
import { GenerateDataModal } from "./GenerateDataModal";
|
||||
|
||||
export const DatasetHeaderButtons = () => {
|
||||
const generateModalDisclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<HStack>
|
||||
<Button leftIcon={<BiImport />} colorScheme="blue" variant="ghost">
|
||||
Import Data
|
||||
</Button>
|
||||
<Button leftIcon={<BsStars />} colorScheme="blue" onClick={generateModalDisclosure.onOpen}>
|
||||
Generate Data
|
||||
</Button>
|
||||
</HStack>
|
||||
<GenerateDataModal
|
||||
isOpen={generateModalDisclosure.isOpen}
|
||||
onClose={generateModalDisclosure.onClose}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -1,128 +0,0 @@
|
||||
import {
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
ModalFooter,
|
||||
Text,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberDecrementStepper,
|
||||
Button,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { useState } from "react";
|
||||
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import AutoResizeTextArea from "~/components/AutoResizeTextArea";
|
||||
|
||||
export const GenerateDataModal = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
}: {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
}) => {
|
||||
const utils = api.useContext();
|
||||
|
||||
const datasetId = useDataset().data?.id;
|
||||
|
||||
const [numToGenerate, setNumToGenerate] = useState<number>(20);
|
||||
const [inputDescription, setInputDescription] = useState<string>(
|
||||
"Each input should contain an email body. Half of the emails should contain event details, and the other half should not.",
|
||||
);
|
||||
const [outputDescription, setOutputDescription] = useState<string>(
|
||||
`Each output should contain "true" or "false", where "true" indicates that the email contains event details.`,
|
||||
);
|
||||
|
||||
const generateEntriesMutation = api.datasetEntries.autogenerateEntries.useMutation();
|
||||
|
||||
const [generateEntries, generateEntriesInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!inputDescription || !outputDescription || !numToGenerate || !datasetId) return;
|
||||
await generateEntriesMutation.mutateAsync({
|
||||
datasetId,
|
||||
inputDescription,
|
||||
outputDescription,
|
||||
numToGenerate,
|
||||
});
|
||||
await utils.datasetEntries.list.invalidate();
|
||||
onClose();
|
||||
}, [
|
||||
generateEntriesMutation,
|
||||
onClose,
|
||||
inputDescription,
|
||||
outputDescription,
|
||||
numToGenerate,
|
||||
datasetId,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BsStars} />
|
||||
<Text>Generate Data</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} padding={8} alignItems="flex-start">
|
||||
<VStack alignItems="flex-start" spacing={2}>
|
||||
<Text fontWeight="bold">Number of Rows:</Text>
|
||||
<NumberInput
|
||||
step={5}
|
||||
defaultValue={15}
|
||||
min={0}
|
||||
max={100}
|
||||
onChange={(valueString) => setNumToGenerate(parseInt(valueString) || 0)}
|
||||
value={numToGenerate}
|
||||
w="24"
|
||||
>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" w="full" spacing={2}>
|
||||
<Text fontWeight="bold">Input Description:</Text>
|
||||
<AutoResizeTextArea
|
||||
value={inputDescription}
|
||||
onChange={(e) => setInputDescription(e.target.value)}
|
||||
placeholder="Each input should contain..."
|
||||
/>
|
||||
</VStack>
|
||||
<VStack alignItems="flex-start" w="full" spacing={2}>
|
||||
<Text fontWeight="bold">Output Description (optional):</Text>
|
||||
<AutoResizeTextArea
|
||||
value={outputDescription}
|
||||
onChange={(e) => setOutputDescription(e.target.value)}
|
||||
placeholder="The output should contain..."
|
||||
/>
|
||||
</VStack>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
isLoading={generateEntriesInProgress}
|
||||
isDisabled={!numToGenerate || !inputDescription || !outputDescription}
|
||||
onClick={generateEntries}
|
||||
>
|
||||
Generate
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -1,13 +0,0 @@
|
||||
import { Td, Tr } from "@chakra-ui/react";
|
||||
import { type DatasetEntry } from "@prisma/client";
|
||||
|
||||
const TableRow = ({ entry }: { entry: DatasetEntry }) => {
|
||||
return (
|
||||
<Tr key={entry.id}>
|
||||
<Td>{entry.input}</Td>
|
||||
<Td>{entry.output}</Td>
|
||||
</Tr>
|
||||
);
|
||||
};
|
||||
|
||||
export default TableRow;
|
||||
65
app/src/components/fineTunes/FineTunesTable.tsx
Normal file
65
app/src/components/fineTunes/FineTunesTable.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { Card, Table, Thead, Tr, Th, Tbody, Td, VStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { FaTable } from "react-icons/fa";
|
||||
import { type FineTuneStatus } from "@prisma/client";
|
||||
|
||||
import dayjs from "~/utils/dayjs";
|
||||
import { useFineTunes } from "~/utils/hooks";
|
||||
|
||||
const FineTunesTable = ({}) => {
|
||||
const { data } = useFineTunes();
|
||||
|
||||
const fineTunes = data?.fineTunes || [];
|
||||
|
||||
return (
|
||||
<Card width="100%" overflowX="auto">
|
||||
{fineTunes.length ? (
|
||||
<Table>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>ID</Th>
|
||||
<Th>Created At</Th>
|
||||
<Th>Base Model</Th>
|
||||
<Th>Dataset Size</Th>
|
||||
<Th>Status</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{fineTunes.map((fineTune) => {
|
||||
return (
|
||||
<Tr key={fineTune.id}>
|
||||
<Td>{fineTune.slug}</Td>
|
||||
<Td>{dayjs(fineTune.createdAt).format("MMMM D h:mm A")}</Td>
|
||||
<Td>{fineTune.baseModel}</Td>
|
||||
<Td>{fineTune.dataset._count.datasetEntries}</Td>
|
||||
<Td fontSize="sm" fontWeight="bold">
|
||||
<Text color={getStatusColor(fineTune.status)}>{fineTune.status}</Text>
|
||||
</Td>
|
||||
</Tr>
|
||||
);
|
||||
})}
|
||||
</Tbody>
|
||||
</Table>
|
||||
) : (
|
||||
<VStack py={8}>
|
||||
<Icon as={FaTable} boxSize={16} color="gray.300" />
|
||||
<Text color="gray.400" fontSize="lg" fontWeight="bold">
|
||||
No Fine Tunes Found
|
||||
</Text>
|
||||
</VStack>
|
||||
)}
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default FineTunesTable;
|
||||
|
||||
const getStatusColor = (status: FineTuneStatus) => {
|
||||
switch (status) {
|
||||
case "DEPLOYED":
|
||||
return "green.500";
|
||||
case "ERROR":
|
||||
return "red.500";
|
||||
default:
|
||||
return "yellow.500";
|
||||
}
|
||||
};
|
||||
@@ -15,12 +15,14 @@ import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { BsGearFill, BsGithub, BsPersonCircle } from "react-icons/bs";
|
||||
import { IoStatsChartOutline } from "react-icons/io5";
|
||||
import { RiHome3Line, RiDatabase2Line, RiFlaskLine } from "react-icons/ri";
|
||||
import { RiHome3Line, RiFlaskLine } from "react-icons/ri";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
import { signIn, useSession } from "next-auth/react";
|
||||
import { env } from "~/env.mjs";
|
||||
import ProjectMenu from "./ProjectMenu";
|
||||
import NavSidebarOption from "./NavSidebarOption";
|
||||
import IconLink from "./IconLink";
|
||||
import { BetaModal } from "./BetaModal";
|
||||
|
||||
const Divider = () => <Box h="1px" bgColor="gray.300" w="full" />;
|
||||
|
||||
@@ -71,21 +73,10 @@ const NavSidebar = () => {
|
||||
<ProjectMenu />
|
||||
<Divider />
|
||||
|
||||
{env.NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS && (
|
||||
<>
|
||||
<IconLink icon={RiHome3Line} label="Dashboard" href="/dashboard" beta />
|
||||
<IconLink
|
||||
icon={IoStatsChartOutline}
|
||||
label="Request Logs"
|
||||
href="/request-logs"
|
||||
beta
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<IconLink icon={RiHome3Line} label="Dashboard" href="/dashboard" beta />
|
||||
<IconLink icon={IoStatsChartOutline} label="Request Logs" href="/request-logs" beta />
|
||||
<IconLink icon={FaRobot} label="Fine Tunes" href="/fine-tunes" beta />
|
||||
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
|
||||
{env.NEXT_PUBLIC_SHOW_DATA && (
|
||||
<IconLink icon={RiDatabase2Line} label="Data" href="/data" />
|
||||
)}
|
||||
<VStack w="full" alignItems="flex-start" spacing={0} pt={8}>
|
||||
<Text
|
||||
pl={2}
|
||||
@@ -141,10 +132,12 @@ export default function AppShell({
|
||||
children,
|
||||
title,
|
||||
requireAuth,
|
||||
requireBeta,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
title?: string;
|
||||
requireAuth?: boolean;
|
||||
requireBeta?: boolean;
|
||||
}) {
|
||||
const [vh, setVh] = useState("100vh"); // Default height to prevent flicker on initial render
|
||||
|
||||
@@ -175,14 +168,17 @@ export default function AppShell({
|
||||
}, [requireAuth, user, authLoading]);
|
||||
|
||||
return (
|
||||
<Flex h={vh} w="100vw">
|
||||
<Head>
|
||||
<title>{title ? `${title} | OpenPipe` : "OpenPipe"}</title>
|
||||
</Head>
|
||||
<NavSidebar />
|
||||
<Box h="100%" flex={1} overflowY="auto" bgColor="gray.50">
|
||||
{children}
|
||||
</Box>
|
||||
</Flex>
|
||||
<>
|
||||
<Flex h={vh} w="100vw">
|
||||
<Head>
|
||||
<title>{title ? `${title} | OpenPipe` : "OpenPipe"}</title>
|
||||
</Head>
|
||||
<NavSidebar />
|
||||
<Box h="100%" flex={1} overflowY="auto" bgColor="gray.50">
|
||||
{children}
|
||||
</Box>
|
||||
</Flex>
|
||||
{requireBeta && !env.NEXT_PUBLIC_FF_SHOW_BETA_FEATURES && <BetaModal />}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
67
app/src/components/nav/BetaModal.tsx
Normal file
67
app/src/components/nav/BetaModal.tsx
Normal file
@@ -0,0 +1,67 @@
|
||||
import {
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
VStack,
|
||||
Text,
|
||||
HStack,
|
||||
Icon,
|
||||
Link,
|
||||
} from "@chakra-ui/react";
|
||||
import { BsStars } from "react-icons/bs";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
|
||||
export const BetaModal = () => {
|
||||
const router = useRouter();
|
||||
const session = useSession();
|
||||
|
||||
const email = session.data?.user.email ?? "";
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen
|
||||
onClose={router.back}
|
||||
closeOnOverlayClick={false}
|
||||
size={{ base: "xl", md: "2xl" }}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={BsStars} />
|
||||
<Text>Beta-Only Feature</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalBody maxW="unset">
|
||||
<VStack spacing={8} py={4} alignItems="flex-start">
|
||||
<Text fontSize="md">
|
||||
This feature is currently in beta. To receive early access to beta-only features, join
|
||||
the waitlist. You'll receive an email at <b>{email}</b> when you're approved.
|
||||
</Text>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<HStack spacing={4}>
|
||||
<Button
|
||||
as={Link}
|
||||
textDecoration="none !important"
|
||||
colorScheme="orange"
|
||||
target="_blank"
|
||||
href={`https://ax3nafkw0jp.typeform.com/to/ZNpYqvAc#email=${email}`}
|
||||
>
|
||||
Join Waitlist
|
||||
</Button>
|
||||
<Button colorScheme="blue" onClick={router.back}>
|
||||
Done
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
161
app/src/components/requestLogs/FineTuneButton.tsx
Normal file
161
app/src/components/requestLogs/FineTuneButton.tsx
Normal file
@@ -0,0 +1,161 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import {
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalCloseButton,
|
||||
ModalBody,
|
||||
ModalFooter,
|
||||
HStack,
|
||||
VStack,
|
||||
Icon,
|
||||
Text,
|
||||
Button,
|
||||
useDisclosure,
|
||||
type UseDisclosureReturn,
|
||||
Input,
|
||||
} from "@chakra-ui/react";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
import humanId from "human-id";
|
||||
import { useRouter } from "next/router";
|
||||
|
||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { useAppStore } from "~/state/store";
|
||||
import ActionButton from "./ActionButton";
|
||||
import InputDropdown from "../InputDropdown";
|
||||
import { FiChevronDown } from "react-icons/fi";
|
||||
|
||||
const SUPPORTED_BASE_MODELS = ["llama2-7b", "llama2-13b", "llama2-70b", "gpt-3.5-turbo"];
|
||||
|
||||
const FineTuneButton = () => {
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
|
||||
const disclosure = useDisclosure();
|
||||
|
||||
return (
|
||||
<>
|
||||
<ActionButton
|
||||
onClick={disclosure.onOpen}
|
||||
label="Fine Tune"
|
||||
icon={FaRobot}
|
||||
isDisabled={selectedLogIds.size === 0}
|
||||
/>
|
||||
<FineTuneModal disclosure={disclosure} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default FineTuneButton;
|
||||
|
||||
const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
|
||||
const selectedProjectId = useAppStore((s) => s.selectedProjectId);
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds);
|
||||
|
||||
const [selectedBaseModel, setSelectedBaseModel] = useState(SUPPORTED_BASE_MODELS[0]);
|
||||
const [modelSlug, setModelSlug] = useState(humanId({ separator: "-", capitalize: false }));
|
||||
|
||||
useEffect(() => {
|
||||
if (disclosure.isOpen) {
|
||||
setSelectedBaseModel(SUPPORTED_BASE_MODELS[0]);
|
||||
setModelSlug(humanId({ separator: "-", capitalize: false }));
|
||||
}
|
||||
}, [disclosure.isOpen]);
|
||||
|
||||
const utils = api.useContext();
|
||||
const router = useRouter();
|
||||
|
||||
const createFineTuneMutation = api.fineTunes.create.useMutation();
|
||||
|
||||
const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => {
|
||||
if (!selectedProjectId || !modelSlug || !selectedBaseModel || !selectedLogIds.size) return;
|
||||
await createFineTuneMutation.mutateAsync({
|
||||
projectId: selectedProjectId,
|
||||
slug: modelSlug,
|
||||
baseModel: selectedBaseModel,
|
||||
selectedLogIds: Array.from(selectedLogIds),
|
||||
});
|
||||
|
||||
await utils.fineTunes.list.invalidate();
|
||||
await router.push({ pathname: "/fine-tunes" });
|
||||
clearSelectedLogIds();
|
||||
disclosure.onClose();
|
||||
}, [createFineTuneMutation, selectedProjectId, selectedLogIds, modelSlug, selectedBaseModel]);
|
||||
|
||||
return (
|
||||
<Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
|
||||
<ModalOverlay />
|
||||
<ModalContent w={1200}>
|
||||
<ModalHeader>
|
||||
<HStack>
|
||||
<Icon as={FaRobot} />
|
||||
<Text>Fine Tune</Text>
|
||||
</HStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody maxW="unset">
|
||||
<VStack w="full" spacing={8} pt={4} alignItems="flex-start">
|
||||
<Text>
|
||||
We'll train on the <b>{selectedLogIds.size}</b> logs you've selected.
|
||||
</Text>
|
||||
<VStack>
|
||||
<HStack spacing={2} w="full">
|
||||
<Text fontWeight="bold" w={36}>
|
||||
Model ID:
|
||||
</Text>
|
||||
<Input
|
||||
value={modelSlug}
|
||||
onChange={(e) => setModelSlug(e.target.value)}
|
||||
w={48}
|
||||
placeholder="unique-id"
|
||||
onKeyDown={(e) => {
|
||||
// If the user types anything other than a-z, A-Z, or 0-9, replace it with -
|
||||
if (!/[a-zA-Z0-9]/.test(e.key)) {
|
||||
e.preventDefault();
|
||||
setModelSlug((s) => s && `${s}-`);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</HStack>
|
||||
<HStack spacing={2}>
|
||||
<Text fontWeight="bold" w={36}>
|
||||
Base model:
|
||||
</Text>
|
||||
<InputDropdown
|
||||
options={SUPPORTED_BASE_MODELS}
|
||||
selectedOption={selectedBaseModel}
|
||||
onSelect={(option) => setSelectedBaseModel(option)}
|
||||
inputGroupProps={{ w: 48 }}
|
||||
/>
|
||||
</HStack>
|
||||
</VStack>
|
||||
<Button variant="unstyled" color="blue.600">
|
||||
<HStack>
|
||||
<Text>Advanced Options</Text>
|
||||
<Icon as={FiChevronDown} />
|
||||
</HStack>
|
||||
</Button>
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<HStack>
|
||||
<Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
onClick={createFineTune}
|
||||
isLoading={creationInProgress}
|
||||
minW={24}
|
||||
isDisabled={!modelSlug}
|
||||
>
|
||||
Start Training
|
||||
</Button>
|
||||
</HStack>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -10,7 +10,7 @@ export default function LoggedCallsTable() {
|
||||
return (
|
||||
<Card width="100%" overflowX="auto">
|
||||
<Table>
|
||||
<TableHeader isSimple />
|
||||
<TableHeader showOptions />
|
||||
<Tbody>
|
||||
{loggedCalls?.calls?.map((loggedCall) => {
|
||||
return (
|
||||
@@ -25,7 +25,7 @@ export default function LoggedCallsTable() {
|
||||
setExpandedRow(loggedCall.id);
|
||||
}
|
||||
}}
|
||||
isSimple
|
||||
showOptions
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -14,10 +14,9 @@ import {
|
||||
Text,
|
||||
Checkbox,
|
||||
} from "@chakra-ui/react";
|
||||
import dayjs from "dayjs";
|
||||
import relativeTime from "dayjs/plugin/relativeTime";
|
||||
import Link from "next/link";
|
||||
|
||||
import dayjs from "~/utils/dayjs";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { FormattedJson } from "./FormattedJson";
|
||||
import { useAppStore } from "~/state/store";
|
||||
@@ -25,11 +24,9 @@ import { useIsClientRehydrated, useLoggedCalls, useTagNames } from "~/utils/hook
|
||||
import { useMemo } from "react";
|
||||
import { StaticColumnKeys } from "~/state/columnVisiblitySlice";
|
||||
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
type LoggedCall = RouterOutputs["loggedCalls"]["list"]["calls"][0];
|
||||
|
||||
export const TableHeader = ({ isSimple }: { isSimple?: boolean }) => {
|
||||
export const TableHeader = ({ showOptions }: { showOptions?: boolean }) => {
|
||||
const matchingLogIds = useLoggedCalls().data?.matchingLogIds;
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
const addAll = useAppStore((s) => s.selectedLogs.addSelectedLogIds);
|
||||
@@ -46,7 +43,7 @@ export const TableHeader = ({ isSimple }: { isSimple?: boolean }) => {
|
||||
return (
|
||||
<Thead>
|
||||
<Tr>
|
||||
{isSimple && (
|
||||
{showOptions && (
|
||||
<Th pr={0}>
|
||||
<HStack minW={16}>
|
||||
<Checkbox
|
||||
@@ -84,12 +81,12 @@ export const TableRow = ({
|
||||
loggedCall,
|
||||
isExpanded,
|
||||
onToggle,
|
||||
isSimple,
|
||||
showOptions,
|
||||
}: {
|
||||
loggedCall: LoggedCall;
|
||||
isExpanded: boolean;
|
||||
onToggle: () => void;
|
||||
isSimple?: boolean;
|
||||
showOptions?: boolean;
|
||||
}) => {
|
||||
const isError = loggedCall.modelResponse?.statusCode !== 200;
|
||||
const requestedAt = dayjs(loggedCall.requestedAt).format("MMMM D h:mm A");
|
||||
@@ -101,6 +98,10 @@ export const TableRow = ({
|
||||
const tagNames = useTagNames().data;
|
||||
const visibleColumns = useAppStore((s) => s.columnVisibility.visibleColumns);
|
||||
|
||||
const visibleTagNames = useMemo(() => {
|
||||
return tagNames?.filter((tagName) => visibleColumns.has(tagName)) ?? [];
|
||||
}, [tagNames, visibleColumns]);
|
||||
|
||||
const isClientRehydrated = useIsClientRehydrated();
|
||||
if (!isClientRehydrated) return null;
|
||||
|
||||
@@ -115,7 +116,7 @@ export const TableRow = ({
|
||||
}}
|
||||
fontSize="sm"
|
||||
>
|
||||
{isSimple && (
|
||||
{showOptions && (
|
||||
<Td>
|
||||
<Checkbox isChecked={isChecked} onChange={() => toggleChecked(loggedCall.id)} />
|
||||
</Td>
|
||||
@@ -147,9 +148,9 @@ export const TableRow = ({
|
||||
</HStack>
|
||||
</Td>
|
||||
)}
|
||||
{tagNames
|
||||
?.filter((tagName) => visibleColumns.has(tagName))
|
||||
.map((tagName) => <Td key={tagName}>{loggedCall.tags[tagName]}</Td>)}
|
||||
{visibleTagNames.map((tagName) => (
|
||||
<Td key={tagName}>{loggedCall.tags[tagName]}</Td>
|
||||
))}
|
||||
{visibleColumns.has(StaticColumnKeys.DURATION) && (
|
||||
<Td isNumeric>
|
||||
{loggedCall.cacheHit ? (
|
||||
@@ -172,7 +173,7 @@ export const TableRow = ({
|
||||
)}
|
||||
</Tr>
|
||||
<Tr>
|
||||
<Td colSpan={8} p={0}>
|
||||
<Td colSpan={visibleColumns.size + 1} w="full" p={0}>
|
||||
<Collapse in={isExpanded} unmountOnExit={true}>
|
||||
<VStack p={4} align="stretch">
|
||||
<HStack align="stretch">
|
||||
|
||||
@@ -46,8 +46,7 @@ export const env = createEnv({
|
||||
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
|
||||
NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"),
|
||||
NEXT_PUBLIC_SENTRY_DSN: z.string().optional(),
|
||||
NEXT_PUBLIC_SHOW_DATA: z.string().optional(),
|
||||
NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS: z.string().optional(),
|
||||
NEXT_PUBLIC_FF_SHOW_BETA_FEATURES: z.string().optional(),
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -62,7 +61,6 @@ export const env = createEnv({
|
||||
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
|
||||
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
|
||||
NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST,
|
||||
NEXT_PUBLIC_SHOW_DATA: process.env.NEXT_PUBLIC_SHOW_DATA,
|
||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||
REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN,
|
||||
@@ -70,7 +68,7 @@ export const env = createEnv({
|
||||
NEXT_PUBLIC_SENTRY_DSN: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
SENTRY_AUTH_TOKEN: process.env.SENTRY_AUTH_TOKEN,
|
||||
OPENPIPE_API_KEY: process.env.OPENPIPE_API_KEY,
|
||||
NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS: process.env.NEXT_PUBLIC_FF_SHOW_LOGGED_CALLS,
|
||||
NEXT_PUBLIC_FF_SHOW_BETA_FEATURES: process.env.NEXT_PUBLIC_FF_SHOW_BETA_FEATURES,
|
||||
SENDER_EMAIL: process.env.SENDER_EMAIL,
|
||||
SMTP_HOST: process.env.SMTP_HOST,
|
||||
SMTP_PORT: process.env.SMTP_PORT,
|
||||
|
||||
@@ -33,7 +33,7 @@ export default function Dashboard() {
|
||||
);
|
||||
|
||||
return (
|
||||
<AppShell title="Dashboard" requireAuth>
|
||||
<AppShell title="Dashboard" requireAuth requireBeta>
|
||||
<VStack px={8} py={8} alignItems="flex-start" spacing={4}>
|
||||
<Text fontSize="2xl" fontWeight="bold">
|
||||
Dashboard
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Breadcrumb,
|
||||
BreadcrumbItem,
|
||||
Center,
|
||||
Flex,
|
||||
Icon,
|
||||
Input,
|
||||
VStack,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
import { useRouter } from "next/router";
|
||||
import { useState, useEffect } from "react";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { api } from "~/utils/api";
|
||||
import { useDataset, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable";
|
||||
import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons";
|
||||
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
|
||||
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
|
||||
|
||||
export default function Dataset() {
|
||||
const router = useRouter();
|
||||
const utils = api.useContext();
|
||||
|
||||
const dataset = useDataset();
|
||||
const datasetId = router.query.id as string;
|
||||
|
||||
const [name, setName] = useState(dataset.data?.name || "");
|
||||
useEffect(() => {
|
||||
setName(dataset.data?.name || "");
|
||||
}, [dataset.data?.name]);
|
||||
|
||||
const updateMutation = api.datasets.update.useMutation();
|
||||
const [onSaveName] = useHandledAsyncCallback(async () => {
|
||||
if (name && name !== dataset.data?.name && dataset.data?.id) {
|
||||
await updateMutation.mutateAsync({
|
||||
id: dataset.data.id,
|
||||
updates: { name: name },
|
||||
});
|
||||
await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]);
|
||||
}
|
||||
}, [updateMutation, dataset.data?.id, dataset.data?.name, name]);
|
||||
|
||||
if (!dataset.isLoading && !dataset.data) {
|
||||
return (
|
||||
<AppShell title="Dataset not found">
|
||||
<Center h="100%">
|
||||
<div>Dataset not found 😕</div>
|
||||
</Center>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AppShell title={dataset.data?.name}>
|
||||
<VStack h="full">
|
||||
<PageHeaderContainer>
|
||||
<Breadcrumb>
|
||||
<BreadcrumbItem>
|
||||
<ProjectBreadcrumbContents projectName={dataset.data?.project?.name} />
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem>
|
||||
<Link href="/data">
|
||||
<Flex alignItems="center" _hover={{ textDecoration: "underline" }}>
|
||||
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||
</Flex>
|
||||
</Link>
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem isCurrentPage>
|
||||
<Input
|
||||
size="sm"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
onBlur={onSaveName}
|
||||
borderWidth={1}
|
||||
borderColor="transparent"
|
||||
fontSize={16}
|
||||
px={0}
|
||||
minW={{ base: 100, lg: 300 }}
|
||||
flex={1}
|
||||
_hover={{ borderColor: "gray.300" }}
|
||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||
/>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
<DatasetHeaderButtons />
|
||||
</PageHeaderContainer>
|
||||
<Box w="full" overflowX="auto" flex={1} px={8} pt={8} pb={16}>
|
||||
{datasetId && <DatasetEntriesTable />}
|
||||
</Box>
|
||||
</VStack>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
import { SimpleGrid, Icon, Breadcrumb, BreadcrumbItem, Flex } from "@chakra-ui/react";
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import { RiDatabase2Line } from "react-icons/ri";
|
||||
import {
|
||||
DatasetCard,
|
||||
DatasetCardSkeleton,
|
||||
NewDatasetCard,
|
||||
} from "~/components/datasets/DatasetCard";
|
||||
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
|
||||
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
|
||||
import { useDatasets } from "~/utils/hooks";
|
||||
|
||||
export default function DatasetsPage() {
|
||||
const datasets = useDatasets();
|
||||
|
||||
return (
|
||||
<AppShell title="Data" requireAuth>
|
||||
<PageHeaderContainer>
|
||||
<Breadcrumb>
|
||||
<BreadcrumbItem>
|
||||
<ProjectBreadcrumbContents />
|
||||
</BreadcrumbItem>
|
||||
<BreadcrumbItem minH={8}>
|
||||
<Flex alignItems="center">
|
||||
<Icon as={RiDatabase2Line} boxSize={4} mr={2} /> Datasets
|
||||
</Flex>
|
||||
</BreadcrumbItem>
|
||||
</Breadcrumb>
|
||||
</PageHeaderContainer>
|
||||
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} py={4} px={8}>
|
||||
<NewDatasetCard />
|
||||
{datasets.data && !datasets.isLoading ? (
|
||||
datasets?.data?.map((dataset) => (
|
||||
<DatasetCard
|
||||
key={dataset.id}
|
||||
dataset={{ ...dataset, numEntries: dataset._count.datasetEntries }}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
<DatasetCardSkeleton />
|
||||
<DatasetCardSkeleton />
|
||||
<DatasetCardSkeleton />
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
18
app/src/pages/fine-tunes/index.tsx
Normal file
18
app/src/pages/fine-tunes/index.tsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import { Text, VStack, Divider } from "@chakra-ui/react";
|
||||
import FineTunesTable from "~/components/fineTunes/FineTunesTable";
|
||||
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
|
||||
export default function FineTunes() {
|
||||
return (
|
||||
<AppShell title="Fine Tunes" requireAuth requireBeta>
|
||||
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
|
||||
<Text fontSize="2xl" fontWeight="bold">
|
||||
Fine Tunes
|
||||
</Text>
|
||||
<Divider />
|
||||
<FineTunesTable />
|
||||
</VStack>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { Text, VStack, Divider, HStack } from "@chakra-ui/react";
|
||||
import { Text, VStack, Divider, HStack, Box } from "@chakra-ui/react";
|
||||
|
||||
import AppShell from "~/components/nav/AppShell";
|
||||
import LoggedCallTable from "~/components/requestLogs/LoggedCallsTable";
|
||||
@@ -10,6 +10,7 @@ import { RiFlaskLine } from "react-icons/ri";
|
||||
import { FiFilter } from "react-icons/fi";
|
||||
import LogFilters from "~/components/requestLogs/LogFilters/LogFilters";
|
||||
import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown";
|
||||
import FineTuneButton from "~/components/requestLogs/FineTuneButton";
|
||||
|
||||
export default function LoggedCalls() {
|
||||
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
|
||||
@@ -17,34 +18,37 @@ export default function LoggedCalls() {
|
||||
const [filtersShown, setFiltersShown] = useState(true);
|
||||
|
||||
return (
|
||||
<AppShell title="Request Logs" requireAuth>
|
||||
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
|
||||
<Text fontSize="2xl" fontWeight="bold">
|
||||
Request Logs
|
||||
</Text>
|
||||
<Divider />
|
||||
<HStack w="full" justifyContent="flex-end">
|
||||
<ColumnVisiblityDropdown />
|
||||
<ActionButton
|
||||
onClick={() => {
|
||||
setFiltersShown(!filtersShown);
|
||||
}}
|
||||
label={filtersShown ? "Hide Filters" : "Show Filters"}
|
||||
icon={FiFilter}
|
||||
/>
|
||||
<ActionButton
|
||||
onClick={() => {
|
||||
console.log("experimenting with these ids", selectedLogIds);
|
||||
}}
|
||||
label="Experiment"
|
||||
icon={RiFlaskLine}
|
||||
isDisabled={selectedLogIds.size === 0}
|
||||
/>
|
||||
</HStack>
|
||||
{filtersShown && <LogFilters />}
|
||||
<LoggedCallTable />
|
||||
<LoggedCallsPaginator />
|
||||
</VStack>
|
||||
<AppShell title="Request Logs" requireAuth requireBeta>
|
||||
<Box h="100vh" overflowY="scroll">
|
||||
<VStack px={8} py={8} alignItems="flex-start" spacing={4} w="full">
|
||||
<Text fontSize="2xl" fontWeight="bold">
|
||||
Request Logs
|
||||
</Text>
|
||||
<Divider />
|
||||
<HStack w="full" justifyContent="flex-end">
|
||||
<FineTuneButton />
|
||||
<ActionButton
|
||||
onClick={() => {
|
||||
console.log("experimenting with these ids", selectedLogIds);
|
||||
}}
|
||||
label="Experiment"
|
||||
icon={RiFlaskLine}
|
||||
isDisabled={selectedLogIds.size === 0}
|
||||
/>
|
||||
<ColumnVisiblityDropdown />
|
||||
<ActionButton
|
||||
onClick={() => {
|
||||
setFiltersShown(!filtersShown);
|
||||
}}
|
||||
label={filtersShown ? "Hide Filters" : "Show Filters"}
|
||||
icon={FiFilter}
|
||||
/>
|
||||
</HStack>
|
||||
{filtersShown && <LogFilters />}
|
||||
<LoggedCallTable />
|
||||
<LoggedCallsPaginator />
|
||||
</VStack>
|
||||
</Box>
|
||||
</AppShell>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { isAxiosError } from "./utils";
|
||||
import { type APIResponse } from "openai/core";
|
||||
import { sleep } from "~/server/utils/sleep";
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithBackoff = async (
|
||||
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||
) => {
|
||||
let completion;
|
||||
let tries = 0;
|
||||
while (tries < MAX_AUTO_RETRIES) {
|
||||
try {
|
||||
completion = await getCompletion();
|
||||
break;
|
||||
} catch (e) {
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
await sleep(calculateDelay(tries));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
tries++;
|
||||
}
|
||||
return completion;
|
||||
};
|
||||
// TODO: Add seeds to ensure batches don't contain duplicate data
|
||||
const MAX_BATCH_SIZE = 5;
|
||||
|
||||
export const autogenerateDatasetEntries = async (
|
||||
numToGenerate: number,
|
||||
inputDescription: string,
|
||||
outputDescription: string,
|
||||
): Promise<{ input: string; output: string }[]> => {
|
||||
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE
|
||||
? numToGenerate % MAX_BATCH_SIZE
|
||||
: MAX_BATCH_SIZE,
|
||||
);
|
||||
|
||||
const getCompletion = (batchSize: number) =>
|
||||
openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "add_list_of_data",
|
||||
description: "Add a list of data to the database",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
rows: {
|
||||
type: "array",
|
||||
description: "The rows of data that match the description",
|
||||
items: {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "The input for this row",
|
||||
},
|
||||
output: {
|
||||
type: "string",
|
||||
description: "The output for this row",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_list_of_data" },
|
||||
temperature: 0.5,
|
||||
openpipe: {
|
||||
tags: {
|
||||
prompt_id: "autogenerateDatasetEntries",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||
);
|
||||
|
||||
const completions = await Promise.all(completionCallbacks);
|
||||
|
||||
const rows = completions.flatMap((completion) => {
|
||||
const parsed = JSON.parse(
|
||||
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||
) as { rows: { input: string; output: string }[] };
|
||||
return parsed.rows;
|
||||
});
|
||||
|
||||
return rows;
|
||||
};
|
||||
@@ -6,11 +6,10 @@ import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.route
|
||||
import { scenarioVarsRouter } from "./routers/scenarioVariables.router";
|
||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||
import { worldChampsRouter } from "./routers/worldChamps.router";
|
||||
import { datasetsRouter } from "./routers/datasets.router";
|
||||
import { datasetEntries } from "./routers/datasetEntries.router";
|
||||
import { projectsRouter } from "./routers/projects.router";
|
||||
import { dashboardRouter } from "./routers/dashboard.router";
|
||||
import { loggedCallsRouter } from "./routers/loggedCalls.router";
|
||||
import { fineTunesRouter } from "./routers/fineTunes.router";
|
||||
import { usersRouter } from "./routers/users.router";
|
||||
import { adminJobsRouter } from "./routers/adminJobs.router";
|
||||
|
||||
@@ -27,11 +26,10 @@ export const appRouter = createTRPCRouter({
|
||||
scenarioVars: scenarioVarsRouter,
|
||||
evaluations: evaluationsRouter,
|
||||
worldChamps: worldChampsRouter,
|
||||
datasets: datasetsRouter,
|
||||
datasetEntries: datasetEntries,
|
||||
projects: projectsRouter,
|
||||
dashboard: dashboardRouter,
|
||||
loggedCalls: loggedCallsRouter,
|
||||
fineTunes: fineTunesRouter,
|
||||
users: usersRouter,
|
||||
adminJobs: adminJobsRouter,
|
||||
});
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl";
|
||||
import { autogenerateDatasetEntries } from "../autogenerate/autogenerateDatasetEntries";
|
||||
|
||||
export const datasetEntries = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(z.object({ datasetId: z.string(), page: z.number(), pageSize: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.datasetId, ctx);
|
||||
|
||||
const { datasetId, page, pageSize } = input;
|
||||
|
||||
const entries = await prisma.datasetEntry.findMany({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
orderBy: { createdAt: "desc" },
|
||||
skip: (page - 1) * pageSize,
|
||||
take: pageSize,
|
||||
});
|
||||
|
||||
const count = await prisma.datasetEntry.count({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
entries,
|
||||
count,
|
||||
};
|
||||
}),
|
||||
createOne: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.create({
|
||||
data: {
|
||||
datasetId: input.datasetId,
|
||||
input: input.input,
|
||||
output: input.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
autogenerateEntries: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
numToGenerate: z.number(),
|
||||
inputDescription: z.string(),
|
||||
outputDescription: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
const dataset = await prisma.dataset.findUnique({
|
||||
where: {
|
||||
id: input.datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dataset) {
|
||||
throw new Error(`Dataset with id ${input.datasetId} does not exist`);
|
||||
}
|
||||
|
||||
const entries = await autogenerateDatasetEntries(
|
||||
input.numToGenerate,
|
||||
input.inputDescription,
|
||||
input.outputDescription,
|
||||
);
|
||||
|
||||
const createdEntries = await prisma.datasetEntry.createMany({
|
||||
data: entries.map((entry) => ({
|
||||
datasetId: input.datasetId,
|
||||
input: entry.input,
|
||||
output: entry.output,
|
||||
})),
|
||||
});
|
||||
|
||||
return createdEntries;
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const datasetId = (
|
||||
await prisma.datasetEntry.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).datasetId;
|
||||
|
||||
await requireCanModifyDataset(datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.datasetEntry.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`dataEntry with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyDataset(existing.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
input: input.updates.input,
|
||||
output: input.updates.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
@@ -1,88 +0,0 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import {
|
||||
requireCanModifyDataset,
|
||||
requireCanModifyProject,
|
||||
requireCanViewDataset,
|
||||
requireCanViewProject,
|
||||
} from "~/utils/accessControl";
|
||||
|
||||
export const datasetsRouter = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(z.object({ projectId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewProject(input.projectId, ctx);
|
||||
|
||||
const datasets = await prisma.dataset.findMany({
|
||||
where: {
|
||||
projectId: input.projectId,
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "desc",
|
||||
},
|
||||
include: {
|
||||
_count: {
|
||||
select: { datasetEntries: true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return datasets;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.id, ctx);
|
||||
return await prisma.dataset.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
include: {
|
||||
project: true,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(z.object({ projectId: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyProject(input.projectId, ctx);
|
||||
|
||||
const numDatasets = await prisma.dataset.count({
|
||||
where: {
|
||||
projectId: input.projectId,
|
||||
},
|
||||
});
|
||||
|
||||
return await prisma.dataset.create({
|
||||
data: {
|
||||
name: `Dataset ${numDatasets + 1}`,
|
||||
projectId: input.projectId,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
return await prisma.dataset.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
|
||||
await prisma.dataset.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
113
app/src/server/api/routers/fineTunes.router.ts
Normal file
113
app/src/server/api/routers/fineTunes.router.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanViewProject, requireCanModifyProject } from "~/utils/accessControl";
|
||||
import { error, success } from "~/utils/errorHandling/standardResponses";
|
||||
|
||||
export const fineTunesRouter = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
projectId: z.string(),
|
||||
page: z.number(),
|
||||
pageSize: z.number(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { projectId, page, pageSize } = input;
|
||||
|
||||
await requireCanViewProject(projectId, ctx);
|
||||
|
||||
const fineTunes = await prisma.fineTune.findMany({
|
||||
where: {
|
||||
projectId,
|
||||
},
|
||||
include: {
|
||||
dataset: {
|
||||
include: {
|
||||
_count: {
|
||||
select: {
|
||||
datasetEntries: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: { createdAt: "asc" },
|
||||
skip: (page - 1) * pageSize,
|
||||
take: pageSize,
|
||||
});
|
||||
|
||||
const count = await prisma.fineTune.count({
|
||||
where: {
|
||||
projectId,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
fineTunes,
|
||||
count,
|
||||
};
|
||||
}),
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
projectId: z.string(),
|
||||
selectedLogIds: z.array(z.string()),
|
||||
slug: z.string(),
|
||||
baseModel: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyProject(input.projectId, ctx);
|
||||
|
||||
const existingFineTune = await prisma.fineTune.findFirst({
|
||||
where: {
|
||||
slug: input.slug,
|
||||
},
|
||||
});
|
||||
|
||||
if (existingFineTune) {
|
||||
return error("A fine tune with that slug already exists");
|
||||
}
|
||||
|
||||
const newDatasetId = uuidv4();
|
||||
|
||||
const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyDatasetInput[] =
|
||||
input.selectedLogIds.map((loggedCallId) => ({
|
||||
loggedCallId,
|
||||
}));
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.dataset.create({
|
||||
data: {
|
||||
id: newDatasetId,
|
||||
name: input.slug,
|
||||
project: {
|
||||
connect: {
|
||||
id: input.projectId,
|
||||
},
|
||||
},
|
||||
datasetEntries: {
|
||||
createMany: {
|
||||
data: datasetEntriesToCreate,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.fineTune.create({
|
||||
data: {
|
||||
projectId: input.projectId,
|
||||
slug: input.slug,
|
||||
baseModel: input.baseModel,
|
||||
datasetId: newDatasetId,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
return success();
|
||||
}),
|
||||
});
|
||||
@@ -78,33 +78,6 @@ export const requireCanModifyProject = async (projectId: string, ctx: TRPCContex
|
||||
}
|
||||
};
|
||||
|
||||
export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => {
|
||||
ctx.markAccessControlRun();
|
||||
|
||||
const dataset = await prisma.dataset.findFirst({
|
||||
where: {
|
||||
id: datasetId,
|
||||
project: {
|
||||
projectUsers: {
|
||||
some: {
|
||||
role: { in: [ProjectUserRole.ADMIN, ProjectUserRole.MEMBER] },
|
||||
userId: ctx.session?.user.id,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!dataset) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
};
|
||||
|
||||
export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => {
|
||||
// Right now all users who can view a dataset can also modify it
|
||||
await requireCanViewDataset(datasetId, ctx);
|
||||
};
|
||||
|
||||
export const requireCanViewExperiment = (experimentId: string, ctx: TRPCContext): Promise<void> => {
|
||||
// Right now all experiments are publicly viewable, so this is a no-op.
|
||||
ctx.markAccessControlRun();
|
||||
|
||||
@@ -26,34 +26,6 @@ export const useExperimentAccess = () => {
|
||||
return useExperiment().data?.access ?? { canView: false, canModify: false };
|
||||
};
|
||||
|
||||
export const useDatasets = () => {
|
||||
const selectedProjectId = useAppStore((state) => state.selectedProjectId);
|
||||
return api.datasets.list.useQuery(
|
||||
{ projectId: selectedProjectId ?? "" },
|
||||
{ enabled: !!selectedProjectId },
|
||||
);
|
||||
};
|
||||
|
||||
export const useDataset = () => {
|
||||
const router = useRouter();
|
||||
const dataset = api.datasets.get.useQuery(
|
||||
{ id: router.query.id as string },
|
||||
{ enabled: !!router.query.id },
|
||||
);
|
||||
|
||||
return dataset;
|
||||
};
|
||||
|
||||
export const useDatasetEntries = () => {
|
||||
const dataset = useDataset();
|
||||
const { page, pageSize } = usePageParams();
|
||||
|
||||
return api.datasetEntries.list.useQuery(
|
||||
{ datasetId: dataset.data?.id ?? "", page, pageSize },
|
||||
{ enabled: dataset.data?.id != null },
|
||||
);
|
||||
};
|
||||
|
||||
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
|
||||
|
||||
export function useHandledAsyncCallback<T extends unknown[], U>(
|
||||
@@ -206,6 +178,16 @@ export const useTagNames = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export const useFineTunes = () => {
|
||||
const selectedProjectId = useAppStore((state) => state.selectedProjectId);
|
||||
const { page, pageSize } = usePageParams();
|
||||
|
||||
return api.fineTunes.list.useQuery(
|
||||
{ projectId: selectedProjectId ?? "", page, pageSize },
|
||||
{ enabled: !!selectedProjectId },
|
||||
);
|
||||
};
|
||||
|
||||
export const useIsClientRehydrated = () => {
|
||||
const isRehydrated = useAppStore((state) => state.isRehydrated);
|
||||
const [isMounted, setIsMounted] = useState(false);
|
||||
|
||||
Reference in New Issue
Block a user