From b98bce8944a269e66307e403e42554d4e7e8143f Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Fri, 4 Aug 2023 11:52:03 -0700 Subject: [PATCH] Add Datasets (#118) * Add dataset (without entries) * Fix dataset hook * Add dataset rows * Add buttons to import/generate data * Add GenerateDataModal * Autogenerate and save data * Fix prettier * Fix types * Add dataset pagination * Fix prettier * Use useDisclosure * Allow generate data modal fadeaway * hide/show data in env var * Fix prettier --- @types/nextjs-routes.d.ts | 2 + .../20230804042305_add_datasets/migration.sql | 28 ++++ prisma/schema.prisma | 27 ++++ .../CustomInstructionsInput.tsx | 20 ++- .../OutputsTable/ScenarioPaginator.tsx | 69 +-------- src/components/Paginator.tsx | 79 ++++++++++ .../RefinePromptModal/RefinePromptModal.tsx | 4 +- src/components/datasets/DatasetCard.tsx | 110 ++++++++++++++ .../datasets/DatasetEntriesPaginator.tsx | 21 +++ .../datasets/DatasetEntriesTable.tsx | 43 ++++++ .../DatasetHeaderButtons.tsx | 26 ++++ .../GenerateDataModal.tsx | 100 ++++++++++++ src/components/datasets/TableRow.tsx | 13 ++ .../DeleteDialog.tsx | 0 .../ExperimentHeaderButtons.tsx} | 2 +- .../useOnForkButtonPressed.tsx | 0 src/components/nav/AppShell.tsx | 71 +++++---- src/env.mjs | 2 + src/pages/data/[id].tsx | 99 ++++++++++++ src/pages/data/index.tsx | 83 ++++++++++ src/pages/experiments/[id].tsx | 4 +- .../autogenerate/autogenerateDatasetInputs.ts | 97 ++++++++++++ .../autogenerateScenarioValues.ts} | 23 +-- src/server/api/autogenerate/utils.ts | 18 +++ src/server/api/root.router.ts | 4 + .../api/routers/datasetEntries.router.ts | 143 ++++++++++++++++++ src/server/api/routers/datasets.router.ts | 91 +++++++++++ src/server/api/routers/scenarios.router.ts | 2 +- src/utils/accessControl.ts | 27 ++++ src/utils/hooks.ts | 20 +++ 30 files changed, 1108 insertions(+), 120 deletions(-) create mode 100644 prisma/migrations/20230804042305_add_datasets/migration.sql rename src/components/{RefinePromptModal => }/CustomInstructionsInput.tsx (83%) create mode 100644 src/components/Paginator.tsx create mode 100644 src/components/datasets/DatasetCard.tsx create mode 100644 src/components/datasets/DatasetEntriesPaginator.tsx create mode 100644 src/components/datasets/DatasetEntriesTable.tsx create mode 100644 src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx create mode 100644 src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx create mode 100644 src/components/datasets/TableRow.tsx rename src/components/experiments/{HeaderButtons => ExperimentHeaderButtons}/DeleteDialog.tsx (100%) rename src/components/experiments/{HeaderButtons/HeaderButtons.tsx => ExperimentHeaderButtons/ExperimentHeaderButtons.tsx} (96%) rename src/components/experiments/{HeaderButtons => ExperimentHeaderButtons}/useOnForkButtonPressed.tsx (100%) create mode 100644 src/pages/data/[id].tsx create mode 100644 src/pages/data/index.tsx create mode 100644 src/server/api/autogenerate/autogenerateDatasetInputs.ts rename src/server/api/{autogen.ts => autogenerate/autogenerateScenarioValues.ts} (85%) create mode 100644 src/server/api/autogenerate/utils.ts create mode 100644 src/server/api/routers/datasetEntries.router.ts create mode 100644 src/server/api/routers/datasets.router.ts diff --git a/@types/nextjs-routes.d.ts b/@types/nextjs-routes.d.ts index ade4d1b..63efc55 100644 --- a/@types/nextjs-routes.d.ts +++ b/@types/nextjs-routes.d.ts @@ -16,6 +16,8 @@ declare module "nextjs-routes" { | StaticRoute<"/api/experiments/og-image"> | StaticRoute<"/api/sentry-example-api"> | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }> + | DynamicRoute<"/data/[id]", { "id": string }> + | StaticRoute<"/data"> | DynamicRoute<"/experiments/[id]", { "id": string }> | StaticRoute<"/experiments"> | StaticRoute<"/"> diff --git a/prisma/migrations/20230804042305_add_datasets/migration.sql b/prisma/migrations/20230804042305_add_datasets/migration.sql new file mode 100644 index 0000000..0fc039e --- /dev/null +++ b/prisma/migrations/20230804042305_add_datasets/migration.sql @@ -0,0 +1,28 @@ +-- CreateTable +CREATE TABLE "Dataset" ( + "id" UUID NOT NULL, + "name" TEXT NOT NULL, + "organizationId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "Dataset_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "DatasetEntry" ( + "id" UUID NOT NULL, + "input" TEXT NOT NULL, + "output" TEXT, + "datasetId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "DatasetEntry_pkey" PRIMARY KEY ("id") +); + +-- AddForeignKey +ALTER TABLE "Dataset" ADD CONSTRAINT "Dataset_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "DatasetEntry" ADD CONSTRAINT "DatasetEntry_datasetId_fkey" FOREIGN KEY ("datasetId") REFERENCES "Dataset"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 16d4c1c..bceb0b4 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -174,6 +174,32 @@ model OutputEvaluation { @@unique([modelResponseId, evaluationId]) } +model Dataset { + id String @id @default(uuid()) @db.Uuid + + name String + datasetEntries DatasetEntry[] + + organizationId String @db.Uuid + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +} + +model DatasetEntry { + id String @id @default(uuid()) @db.Uuid + + input String + output String? + + datasetId String @db.Uuid + dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +} + model Organization { id String @id @default(uuid()) @db.Uuid personalOrgUserId String? @unique @db.Uuid @@ -183,6 +209,7 @@ model Organization { updatedAt DateTime @updatedAt organizationUsers OrganizationUser[] experiments Experiment[] + datasets Dataset[] } enum OrganizationUserRole { diff --git a/src/components/RefinePromptModal/CustomInstructionsInput.tsx b/src/components/CustomInstructionsInput.tsx similarity index 83% rename from src/components/RefinePromptModal/CustomInstructionsInput.tsx rename to src/components/CustomInstructionsInput.tsx index 6bc5aa0..6303b38 100644 --- a/src/components/RefinePromptModal/CustomInstructionsInput.tsx +++ b/src/components/CustomInstructionsInput.tsx @@ -1,18 +1,29 @@ -import { Button, Spinner, InputGroup, InputRightElement, Icon, HStack } from "@chakra-ui/react"; +import { + Button, + Spinner, + InputGroup, + InputRightElement, + Icon, + HStack, + type InputGroupProps, +} from "@chakra-ui/react"; import { IoMdSend } from "react-icons/io"; -import AutoResizeTextArea from "../AutoResizeTextArea"; +import AutoResizeTextArea from "./AutoResizeTextArea"; export const CustomInstructionsInput = ({ instructions, setInstructions, loading, onSubmit, + placeholder = "Send custom instructions", + ...props }: { instructions: string; setInstructions: (instructions: string) => void; loading: boolean; onSubmit: () => void; -}) => { + placeholder?: string; +} & InputGroupProps) => { return ( { - const [page, setPage] = usePage(); const { data } = useScenarios(); if (!data) return null; const { scenarios, startIndex, lastPage, count } = data; - const nextPage = () => { - if (page < lastPage) { - setPage(page + 1, "replace"); - } - }; - - const prevPage = () => { - if (page > 1) { - setPage(page - 1, "replace"); - } - }; - - const goToLastPage = () => setPage(lastPage, "replace"); - const goToFirstPage = () => setPage(1, "replace"); - return ( - - } - /> - } - /> - - {startIndex}-{startIndex + scenarios.length - 1} / {count} - - } - /> - } - /> - + ); }; diff --git a/src/components/Paginator.tsx b/src/components/Paginator.tsx new file mode 100644 index 0000000..424b07e --- /dev/null +++ b/src/components/Paginator.tsx @@ -0,0 +1,79 @@ +import { Box, HStack, IconButton } from "@chakra-ui/react"; +import { + BsChevronDoubleLeft, + BsChevronDoubleRight, + BsChevronLeft, + BsChevronRight, +} from "react-icons/bs"; +import { usePage } from "~/utils/hooks"; + +const Paginator = ({ + numItemsLoaded, + startIndex, + lastPage, + count, +}: { + numItemsLoaded: number; + startIndex: number; + lastPage: number; + count: number; +}) => { + const [page, setPage] = usePage(); + + const nextPage = () => { + if (page < lastPage) { + setPage(page + 1, "replace"); + } + }; + + const prevPage = () => { + if (page > 1) { + setPage(page - 1, "replace"); + } + }; + + const goToLastPage = () => setPage(lastPage, "replace"); + const goToFirstPage = () => setPage(1, "replace"); + + return ( + + } + /> + } + /> + + {startIndex}-{startIndex + numItemsLoaded - 1} / {count} + + } + /> + } + /> + + ); +}; + +export default Paginator; diff --git a/src/components/RefinePromptModal/RefinePromptModal.tsx b/src/components/RefinePromptModal/RefinePromptModal.tsx index 3d9db2f..8e66e37 100644 --- a/src/components/RefinePromptModal/RefinePromptModal.tsx +++ b/src/components/RefinePromptModal/RefinePromptModal.tsx @@ -20,7 +20,7 @@ import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { type PromptVariant } from "@prisma/client"; import { useState } from "react"; import CompareFunctions from "./CompareFunctions"; -import { CustomInstructionsInput } from "./CustomInstructionsInput"; +import { CustomInstructionsInput } from "../CustomInstructionsInput"; import { RefineAction } from "./RefineAction"; import { isObject, isString } from "lodash-es"; import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types"; @@ -122,7 +122,7 @@ export const RefinePromptModal = ({ instructions={instructions} setInstructions={setInstructions} loading={modificationInProgress} - onSubmit={getModifiedPromptFn} + onSubmit={() => getModifiedPromptFn()} /> { + return ( + + + + + {dataset.name} + + + + + + Created {formatTimePast(dataset.createdAt)} + + Updated {formatTimePast(dataset.updatedAt)} + + + + ); +}; + +const CountLabel = ({ label, count }: { label: string; count: number }) => { + return ( + + + {label} + + + {count} + + + ); +}; + +export const NewDatasetCard = () => { + const router = useRouter(); + const createMutation = api.datasets.create.useMutation(); + const [createDataset, isLoading] = useHandledAsyncCallback(async () => { + const newDataset = await createMutation.mutateAsync({ label: "New Dataset" }); + await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); + }, [createMutation, router]); + + return ( + + + + + New Dataset + + + + ); +}; + +export const DatasetCardSkeleton = () => ( + + + + + + + +); diff --git a/src/components/datasets/DatasetEntriesPaginator.tsx b/src/components/datasets/DatasetEntriesPaginator.tsx new file mode 100644 index 0000000..51ec09a --- /dev/null +++ b/src/components/datasets/DatasetEntriesPaginator.tsx @@ -0,0 +1,21 @@ +import { useDatasetEntries } from "~/utils/hooks"; +import Paginator from "../Paginator"; + +const DatasetEntriesPaginator = () => { + const { data } = useDatasetEntries(); + + if (!data) return null; + + const { entries, startIndex, lastPage, count } = data; + + return ( + + ); +}; + +export default DatasetEntriesPaginator; diff --git a/src/components/datasets/DatasetEntriesTable.tsx b/src/components/datasets/DatasetEntriesTable.tsx new file mode 100644 index 0000000..8148b2d --- /dev/null +++ b/src/components/datasets/DatasetEntriesTable.tsx @@ -0,0 +1,43 @@ +import { + type StackProps, + VStack, + Table, + Th, + Tr, + Thead, + Tbody, + Text, + HStack, +} from "@chakra-ui/react"; +import { useDatasetEntries } from "~/utils/hooks"; +import TableRow from "./TableRow"; +import DatasetEntriesPaginator from "./DatasetEntriesPaginator"; + +const DatasetEntriesTable = (props: StackProps) => { + const { data } = useDatasetEntries(); + + return ( + + + + + + + + + {data?.entries.map((entry) => )} +
InputOutput
+ {(!data || data.entries.length) === 0 ? ( + + No entries found + + ) : ( + + + + )} +
+ ); +}; + +export default DatasetEntriesTable; diff --git a/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx b/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx new file mode 100644 index 0000000..840b8e0 --- /dev/null +++ b/src/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons.tsx @@ -0,0 +1,26 @@ +import { Button, HStack, useDisclosure } from "@chakra-ui/react"; +import { BiImport } from "react-icons/bi"; +import { BsStars } from "react-icons/bs"; + +import { GenerateDataModal } from "./GenerateDataModal"; + +export const DatasetHeaderButtons = () => { + const generateModalDisclosure = useDisclosure(); + + return ( + <> + + + + + + + ); +}; diff --git a/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx b/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx new file mode 100644 index 0000000..d715954 --- /dev/null +++ b/src/components/datasets/DatasetHeaderButtons/GenerateDataModal.tsx @@ -0,0 +1,100 @@ +import { + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalHeader, + ModalOverlay, + ModalFooter, + Text, + HStack, + VStack, + Icon, + NumberInput, + NumberInputField, + NumberInputStepper, + NumberIncrementStepper, + NumberDecrementStepper, +} from "@chakra-ui/react"; +import { BsStars } from "react-icons/bs"; +import { useState } from "react"; +import { CustomInstructionsInput } from "~/components/CustomInstructionsInput"; +import { useDataset, useHandledAsyncCallback } from "~/utils/hooks"; +import { api } from "~/utils/api"; + +export const GenerateDataModal = ({ + isOpen, + onClose, +}: { + isOpen: boolean; + onClose: () => void; +}) => { + const utils = api.useContext(); + + const datasetId = useDataset().data?.id; + const [instructions, setInstructions] = useState( + "Each row should contain an email body. Half of the emails should contain event details, and the other half should not.", + ); + const [numToGenerate, setNumToGenerate] = useState(20); + + const generateInputsMutation = api.datasetEntries.autogenerateInputs.useMutation(); + + const [generateEntries, generateEntriesInProgress] = useHandledAsyncCallback(async () => { + if (!instructions || !numToGenerate || !datasetId) return; + await generateInputsMutation.mutateAsync({ + datasetId, + instructions, + numToGenerate, + }); + await utils.datasetEntries.list.invalidate(); + onClose(); + }, [generateInputsMutation, onClose, instructions, numToGenerate, datasetId]); + + return ( + + + + + + + Generate Data + + + + + + + Number of Rows: + setNumToGenerate(parseInt(valueString) || 0)} + value={numToGenerate} + w="24" + > + + + + + + + + + Row Description: + + + + + + + + ); +}; diff --git a/src/components/datasets/TableRow.tsx b/src/components/datasets/TableRow.tsx new file mode 100644 index 0000000..08ad2bb --- /dev/null +++ b/src/components/datasets/TableRow.tsx @@ -0,0 +1,13 @@ +import { Td, Tr } from "@chakra-ui/react"; +import { type DatasetEntry } from "@prisma/client"; + +const TableRow = ({ entry }: { entry: DatasetEntry }) => { + return ( + + {entry.input} + {entry.output} + + ); +}; + +export default TableRow; diff --git a/src/components/experiments/HeaderButtons/DeleteDialog.tsx b/src/components/experiments/ExperimentHeaderButtons/DeleteDialog.tsx similarity index 100% rename from src/components/experiments/HeaderButtons/DeleteDialog.tsx rename to src/components/experiments/ExperimentHeaderButtons/DeleteDialog.tsx diff --git a/src/components/experiments/HeaderButtons/HeaderButtons.tsx b/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx similarity index 96% rename from src/components/experiments/HeaderButtons/HeaderButtons.tsx rename to src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx index 550d2a4..97960b6 100644 --- a/src/components/experiments/HeaderButtons/HeaderButtons.tsx +++ b/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx @@ -5,7 +5,7 @@ import { BsGearFill } from "react-icons/bs"; import { TbGitFork } from "react-icons/tb"; import { useAppStore } from "~/state/store"; -export const HeaderButtons = () => { +export const ExperimentHeaderButtons = () => { const experiment = useExperiment(); const canModify = experiment.data?.access.canModify ?? false; diff --git a/src/components/experiments/HeaderButtons/useOnForkButtonPressed.tsx b/src/components/experiments/ExperimentHeaderButtons/useOnForkButtonPressed.tsx similarity index 100% rename from src/components/experiments/HeaderButtons/useOnForkButtonPressed.tsx rename to src/components/experiments/ExperimentHeaderButtons/useOnForkButtonPressed.tsx diff --git a/src/components/nav/AppShell.tsx b/src/components/nav/AppShell.tsx index 5610286..b6aec3e 100644 --- a/src/components/nav/AppShell.tsx +++ b/src/components/nav/AppShell.tsx @@ -8,42 +8,43 @@ import { Text, Box, type BoxProps, - type LinkProps, - Link, + Link as ChakraLink, Flex, } from "@chakra-ui/react"; import Head from "next/head"; +import Link, { type LinkProps } from "next/link"; import { BsGithub, BsPersonCircle } from "react-icons/bs"; import { useRouter } from "next/router"; import { type IconType } from "react-icons"; -import { RiFlaskLine } from "react-icons/ri"; +import { RiDatabase2Line, RiFlaskLine } from "react-icons/ri"; import { signIn, useSession } from "next-auth/react"; import UserMenu from "./UserMenu"; +import { env } from "~/env.mjs"; -type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType }; +type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType; href: string }; -const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => { +const IconLink = ({ icon, label, href, color, ...props }: IconLinkProps) => { const router = useRouter(); const isActive = href && router.pathname.startsWith(href); return ( - - - - {label} - - + + + + + {label} + + + ); }; @@ -72,16 +73,28 @@ const NavSidebar = () => { {user != null && ( <> + {env.NEXT_PUBLIC_SHOW_DATA && ( + + )} )} {user === null && ( - { signIn("github").catch(console.error); }} - /> + > + + + Sign In + + )} {user ? ( @@ -90,7 +103,7 @@ const NavSidebar = () => { )} - { p={2} > - + ); diff --git a/src/env.mjs b/src/env.mjs index 686def3..bec0e5e 100644 --- a/src/env.mjs +++ b/src/env.mjs @@ -32,6 +32,7 @@ export const env = createEnv({ NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"), NEXT_PUBLIC_HOST: z.string().url().default("http://localhost:3000"), NEXT_PUBLIC_SENTRY_DSN: z.string().optional(), + NEXT_PUBLIC_SHOW_DATA: z.string().optional(), }, /** @@ -46,6 +47,7 @@ export const env = createEnv({ NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY, NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL, NEXT_PUBLIC_HOST: process.env.NEXT_PUBLIC_HOST, + NEXT_PUBLIC_SHOW_DATA: process.env.NEXT_PUBLIC_SHOW_DATA, GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET, REPLICATE_API_TOKEN: process.env.REPLICATE_API_TOKEN, diff --git a/src/pages/data/[id].tsx b/src/pages/data/[id].tsx new file mode 100644 index 0000000..59d890c --- /dev/null +++ b/src/pages/data/[id].tsx @@ -0,0 +1,99 @@ +import { + Box, + Breadcrumb, + BreadcrumbItem, + Center, + Flex, + Icon, + Input, + VStack, +} from "@chakra-ui/react"; +import Link from "next/link"; + +import { useRouter } from "next/router"; +import { useState, useEffect } from "react"; +import { RiDatabase2Line } from "react-icons/ri"; +import AppShell from "~/components/nav/AppShell"; +import { api } from "~/utils/api"; +import { useDataset, useHandledAsyncCallback } from "~/utils/hooks"; +import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable"; +import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons/DatasetHeaderButtons"; + +export default function Dataset() { + const router = useRouter(); + const utils = api.useContext(); + + const dataset = useDataset(); + const datasetId = router.query.id as string; + + const [name, setName] = useState(dataset.data?.name || ""); + useEffect(() => { + setName(dataset.data?.name || ""); + }, [dataset.data?.name]); + + const updateMutation = api.datasets.update.useMutation(); + const [onSaveName] = useHandledAsyncCallback(async () => { + if (name && name !== dataset.data?.name && dataset.data?.id) { + await updateMutation.mutateAsync({ + id: dataset.data.id, + updates: { name: name }, + }); + await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]); + } + }, [updateMutation, dataset.data?.id, dataset.data?.name, name]); + + if (!dataset.isLoading && !dataset.data) { + return ( + +
+
Dataset not found 😕
+
+
+ ); + } + + return ( + + + + + + + + Datasets + + + + + setName(e.target.value)} + onBlur={onSaveName} + borderWidth={1} + borderColor="transparent" + fontSize={16} + px={0} + minW={{ base: 100, lg: 300 }} + flex={1} + _hover={{ borderColor: "gray.300" }} + _focus={{ borderColor: "blue.500", outline: "none" }} + /> + + + + + + {datasetId && } + + + + ); +} diff --git a/src/pages/data/index.tsx b/src/pages/data/index.tsx new file mode 100644 index 0000000..b1e581e --- /dev/null +++ b/src/pages/data/index.tsx @@ -0,0 +1,83 @@ +import { + SimpleGrid, + Icon, + VStack, + Breadcrumb, + BreadcrumbItem, + Flex, + Center, + Text, + Link, + HStack, +} from "@chakra-ui/react"; +import AppShell from "~/components/nav/AppShell"; +import { api } from "~/utils/api"; +import { signIn, useSession } from "next-auth/react"; +import { RiDatabase2Line } from "react-icons/ri"; +import { + DatasetCard, + DatasetCardSkeleton, + NewDatasetCard, +} from "~/components/datasets/DatasetCard"; + +export default function DatasetsPage() { + const datasets = api.datasets.list.useQuery(); + + const user = useSession().data; + const authLoading = useSession().status === "loading"; + + if (user === null || authLoading) { + return ( + +
+ {!authLoading && ( + + { + signIn("github").catch(console.error); + }} + textDecor="underline" + > + Sign in + {" "} + to view or create new datasets! + + )} +
+
+ ); + } + + return ( + + + + + + + Datasets + + + + + + + {datasets.data && !datasets.isLoading ? ( + datasets?.data?.map((dataset) => ( + + )) + ) : ( + <> + + + + + )} + + + + ); +} diff --git a/src/pages/experiments/[id].tsx b/src/pages/experiments/[id].tsx index c06c9d7..998b19c 100644 --- a/src/pages/experiments/[id].tsx +++ b/src/pages/experiments/[id].tsx @@ -21,7 +21,7 @@ import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useAppStore } from "~/state/store"; import { useSyncVariantEditor } from "~/state/sync"; -import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons"; +import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons"; import Head from "next/head"; // TODO: import less to fix deployment with server side props @@ -142,7 +142,7 @@ export default function Experiment() { )} - + diff --git a/src/server/api/autogenerate/autogenerateDatasetInputs.ts b/src/server/api/autogenerate/autogenerateDatasetInputs.ts new file mode 100644 index 0000000..0f6ab97 --- /dev/null +++ b/src/server/api/autogenerate/autogenerateDatasetInputs.ts @@ -0,0 +1,97 @@ +import { type ChatCompletion } from "openai/resources/chat"; +import { openai } from "../../utils/openai"; +import { isAxiosError } from "./utils"; +import { type APIResponse } from "openai/core"; +import { sleep } from "~/server/utils/sleep"; + +const MAX_AUTO_RETRIES = 50; +const MIN_DELAY = 500; // milliseconds +const MAX_DELAY = 15000; // milliseconds + +function calculateDelay(numPreviousTries: number): number { + const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries)); + const jitter = Math.random() * baseDelay; + return baseDelay + jitter; +} + +const getCompletionWithBackoff = async ( + getCompletion: () => Promise>, +) => { + let completion; + let tries = 0; + while (tries < MAX_AUTO_RETRIES) { + try { + completion = await getCompletion(); + break; + } catch (e) { + if (isAxiosError(e)) { + console.error(e?.response?.data?.error?.message); + } else { + await sleep(calculateDelay(tries)); + console.error(e); + } + } + tries++; + } + return completion; +}; + +const MAX_BATCH_SIZE = 5; + +export const autogenerateDatasetInputs = async ( + numToGenerate: number, + customInstructions: string, +): Promise => { + const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) => + i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 + ? numToGenerate % MAX_BATCH_SIZE + : MAX_BATCH_SIZE, + ); + + const getCompletion = (batchSize: number) => + openai.chat.completions.create({ + model: "gpt-4", + messages: [ + { + role: "system", + content: `The user needs ${batchSize} rows of data that match the following instructions:\n---\n" + ${customInstructions}`, + }, + ], + functions: [ + { + name: "add_list_of_data", + description: "Add a list of data to the database", + parameters: { + type: "object", + properties: { + rows: { + type: "array", + description: "The rows of data that match the instructions", + items: { + type: "string", + }, + }, + }, + }, + }, + ], + + function_call: { name: "add_list_of_data" }, + temperature: 0.5, + }); + + const completionCallbacks = batchSizes.map((batchSize) => + getCompletionWithBackoff(() => getCompletion(batchSize)), + ); + + const completions = await Promise.all(completionCallbacks); + + const rows = completions.flatMap((completion) => { + const parsed = JSON.parse( + completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}", + ) as { rows: string[] }; + return parsed.rows; + }); + + return rows; +}; diff --git a/src/server/api/autogen.ts b/src/server/api/autogenerate/autogenerateScenarioValues.ts similarity index 85% rename from src/server/api/autogen.ts rename to src/server/api/autogenerate/autogenerateScenarioValues.ts index fcfc8dd..25689ab 100644 --- a/src/server/api/autogen.ts +++ b/src/server/api/autogenerate/autogenerateScenarioValues.ts @@ -1,26 +1,9 @@ import { type CompletionCreateParams } from "openai/resources/chat"; -import { prisma } from "../db"; -import { openai } from "../utils/openai"; +import { prisma } from "../../db"; +import { openai } from "../../utils/openai"; import { pick } from "lodash-es"; +import { isAxiosError } from "./utils"; -type AxiosError = { - response?: { - data?: { - error?: { - message?: string; - }; - }; - }; -}; - -function isAxiosError(error: unknown): error is AxiosError { - if (typeof error === "object" && error !== null) { - // Initial check - const err = error as AxiosError; - return err.response?.data?.error?.message !== undefined; // Check structure - } - return false; -} export const autogenerateScenarioValues = async ( experimentId: string, ): Promise> => { diff --git a/src/server/api/autogenerate/utils.ts b/src/server/api/autogenerate/utils.ts new file mode 100644 index 0000000..6b63719 --- /dev/null +++ b/src/server/api/autogenerate/utils.ts @@ -0,0 +1,18 @@ +type AxiosError = { + response?: { + data?: { + error?: { + message?: string; + }; + }; + }; +}; + +export function isAxiosError(error: unknown): error is AxiosError { + if (typeof error === "object" && error !== null) { + // Initial check + const err = error as AxiosError; + return err.response?.data?.error?.message !== undefined; // Check structure + } + return false; +} diff --git a/src/server/api/root.router.ts b/src/server/api/root.router.ts index 345e301..6ec1d8e 100644 --- a/src/server/api/root.router.ts +++ b/src/server/api/root.router.ts @@ -6,6 +6,8 @@ import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.route import { templateVarsRouter } from "./routers/templateVariables.router"; import { evaluationsRouter } from "./routers/evaluations.router"; import { worldChampsRouter } from "./routers/worldChamps.router"; +import { datasetsRouter } from "./routers/datasets.router"; +import { datasetEntries } from "./routers/datasetEntries.router"; /** * This is the primary router for your server. @@ -20,6 +22,8 @@ export const appRouter = createTRPCRouter({ templateVars: templateVarsRouter, evaluations: evaluationsRouter, worldChamps: worldChampsRouter, + datasets: datasetsRouter, + datasetEntries: datasetEntries, }); // export type definition of API diff --git a/src/server/api/routers/datasetEntries.router.ts b/src/server/api/routers/datasetEntries.router.ts new file mode 100644 index 0000000..1bf83b5 --- /dev/null +++ b/src/server/api/routers/datasetEntries.router.ts @@ -0,0 +1,143 @@ +import { z } from "zod"; +import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl"; +import { autogenerateDatasetInputs } from "../autogenerate/autogenerateDatasetInputs"; + +const PAGE_SIZE = 10; + +export const datasetEntries = createTRPCRouter({ + list: protectedProcedure + .input(z.object({ datasetId: z.string(), page: z.number() })) + .query(async ({ input, ctx }) => { + await requireCanViewDataset(input.datasetId, ctx); + + const { datasetId, page } = input; + + const entries = await prisma.datasetEntry.findMany({ + where: { + datasetId, + }, + orderBy: { createdAt: "asc" }, + skip: (page - 1) * PAGE_SIZE, + take: PAGE_SIZE, + }); + + const count = await prisma.datasetEntry.count({ + where: { + datasetId, + }, + }); + + return { + entries, + startIndex: (page - 1) * PAGE_SIZE + 1, + lastPage: Math.ceil(count / PAGE_SIZE), + count, + }; + }), + createOne: protectedProcedure + .input( + z.object({ + datasetId: z.string(), + input: z.string(), + output: z.string().optional(), + }), + ) + .mutation(async ({ input, ctx }) => { + await requireCanModifyDataset(input.datasetId, ctx); + + return await prisma.datasetEntry.create({ + data: { + datasetId: input.datasetId, + input: input.input, + output: input.output, + }, + }); + }), + + autogenerateInputs: protectedProcedure + .input( + z.object({ + datasetId: z.string(), + numToGenerate: z.number(), + instructions: z.string(), + }), + ) + .mutation(async ({ input, ctx }) => { + await requireCanModifyDataset(input.datasetId, ctx); + + const dataset = await prisma.dataset.findUnique({ + where: { + id: input.datasetId, + }, + }); + + if (!dataset) { + throw new Error(`Dataset with id ${input.datasetId} does not exist`); + } + + const entryInputs = await autogenerateDatasetInputs(input.numToGenerate, input.instructions); + + const createdEntries = await prisma.datasetEntry.createMany({ + data: entryInputs.map((entryInput) => ({ + datasetId: input.datasetId, + input: entryInput, + })), + }); + + return createdEntries; + }), + + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + const datasetId = ( + await prisma.datasetEntry.findUniqueOrThrow({ + where: { id: input.id }, + }) + ).datasetId; + + await requireCanModifyDataset(datasetId, ctx); + + return await prisma.datasetEntry.delete({ + where: { + id: input.id, + }, + }); + }), + + update: protectedProcedure + .input( + z.object({ + id: z.string(), + updates: z.object({ + input: z.string(), + output: z.string().optional(), + }), + }), + ) + .mutation(async ({ input, ctx }) => { + const existing = await prisma.datasetEntry.findUnique({ + where: { + id: input.id, + }, + }); + + if (!existing) { + throw new Error(`dataEntry with id ${input.id} does not exist`); + } + + await requireCanModifyDataset(existing.datasetId, ctx); + + return await prisma.datasetEntry.update({ + where: { + id: input.id, + }, + data: { + input: input.updates.input, + output: input.updates.output, + }, + }); + }), +}); diff --git a/src/server/api/routers/datasets.router.ts b/src/server/api/routers/datasets.router.ts new file mode 100644 index 0000000..b25fde4 --- /dev/null +++ b/src/server/api/routers/datasets.router.ts @@ -0,0 +1,91 @@ +import { z } from "zod"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import { + requireCanModifyDataset, + requireCanViewDataset, + requireNothing, +} from "~/utils/accessControl"; +import userOrg from "~/server/utils/userOrg"; + +export const datasetsRouter = createTRPCRouter({ + list: protectedProcedure.query(async ({ ctx }) => { + // Anyone can list experiments + requireNothing(ctx); + + const datasets = await prisma.dataset.findMany({ + where: { + organization: { + organizationUsers: { + some: { userId: ctx.session.user.id }, + }, + }, + }, + orderBy: { + createdAt: "desc", + }, + include: { + _count: { + select: { datasetEntries: true }, + }, + }, + }); + + return datasets; + }), + + get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { + await requireCanViewDataset(input.id, ctx); + return await prisma.dataset.findFirstOrThrow({ + where: { id: input.id }, + }); + }), + + create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { + // Anyone can create an experiment + requireNothing(ctx); + + const numDatasets = await prisma.dataset.count({ + where: { + organization: { + organizationUsers: { + some: { userId: ctx.session.user.id }, + }, + }, + }, + }); + + return await prisma.dataset.create({ + data: { + name: `Dataset ${numDatasets + 1}`, + organizationId: (await userOrg(ctx.session.user.id)).id, + }, + }); + }), + + update: protectedProcedure + .input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) })) + .mutation(async ({ input, ctx }) => { + await requireCanModifyDataset(input.id, ctx); + return await prisma.dataset.update({ + where: { + id: input.id, + }, + data: { + name: input.updates.name, + }, + }); + }), + + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + await requireCanModifyDataset(input.id, ctx); + + await prisma.dataset.delete({ + where: { + id: input.id, + }, + }); + }), +}); diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts index acaf7fa..9211556 100644 --- a/src/server/api/routers/scenarios.router.ts +++ b/src/server/api/routers/scenarios.router.ts @@ -1,7 +1,7 @@ import { z } from "zod"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; -import { autogenerateScenarioValues } from "../autogen"; +import { autogenerateScenarioValues } from "../autogenerate/autogenerateScenarioValues"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { runAllEvals } from "~/server/utils/evaluations"; import { generateNewCell } from "~/server/utils/generateNewCell"; diff --git a/src/utils/accessControl.ts b/src/utils/accessControl.ts index bf19351..5fc1fd4 100644 --- a/src/utils/accessControl.ts +++ b/src/utils/accessControl.ts @@ -16,6 +16,33 @@ export const requireNothing = (ctx: TRPCContext) => { ctx.markAccessControlRun(); }; +export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => { + const dataset = await prisma.dataset.findFirst({ + where: { + id: datasetId, + organization: { + organizationUsers: { + some: { + role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, + userId: ctx.session?.user.id, + }, + }, + }, + }, + }); + + if (!dataset) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + ctx.markAccessControlRun(); +}; + +export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => { + // Right now all users who can view a dataset can also modify it + await requireCanViewDataset(datasetId, ctx); +}; + export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => { await prisma.experiment.findFirst({ where: { id: experimentId }, diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts index 75b9ddc..451c7dc 100644 --- a/src/utils/hooks.ts +++ b/src/utils/hooks.ts @@ -17,6 +17,26 @@ export const useExperimentAccess = () => { return useExperiment().data?.access ?? { canView: false, canModify: false }; }; +export const useDataset = () => { + const router = useRouter(); + const dataset = api.datasets.get.useQuery( + { id: router.query.id as string }, + { enabled: !!router.query.id }, + ); + + return dataset; +}; + +export const useDatasetEntries = () => { + const dataset = useDataset(); + const [page] = usePage(); + + return api.datasetEntries.list.useQuery( + { datasetId: dataset.data?.id ?? "", page }, + { enabled: dataset.data?.id != null }, + ); +}; + type AsyncFunction = (...args: T) => Promise; export function useHandledAsyncCallback(