Allow experiment forking (#89)

* Move DeleteButton into a separate file

* Rename plural relations

* Add ability to fork

* Fork automatically after auth upon return

* Add experiment card skeleton

* Create HeaderButtons component

* return no header buttons while experiment loading

* Fix prettier

* Remove unused variable

* Remove newline

* Default json values to undefined

* Change header styles

* Fix prettier

* Give AddScenario icon less width

* Move useEffect

* Skip invalidating experiments list after forking

* Require user to be able to view experiment to fork it

* Move experiment creation into same transaction

* Only return the forked experiment id

* Put delete button in experiment settings drawer

* Move useEffect hook
This commit is contained in:
arcticfly
2023-07-24 18:10:59 -07:00
committed by GitHub
parent 09140f8b5f
commit d6b97b29f7
18 changed files with 493 additions and 131 deletions

View File

@@ -81,6 +81,7 @@
"tsx": "^3.12.7", "tsx": "^3.12.7",
"type-fest": "^4.0.0", "type-fest": "^4.0.0",
"use-query-params": "^2.2.1", "use-query-params": "^2.2.1",
"uuid": "^9.0.0",
"vite-tsconfig-paths": "^4.2.0", "vite-tsconfig-paths": "^4.2.0",
"zod": "^3.21.4", "zod": "^3.21.4",
"zustand": "^4.3.9" "zustand": "^4.3.9"
@@ -101,6 +102,7 @@
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4", "@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7", "@types/react-syntax-highlighter": "^15.5.7",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.59.6", "@typescript-eslint/eslint-plugin": "^5.59.6",
"@typescript-eslint/parser": "^5.59.6", "@typescript-eslint/parser": "^5.59.6",
"eslint": "^8.40.0", "eslint": "^8.40.0",

17
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.1' lockfileVersion: '6.0'
settings: settings:
autoInstallPeers: true autoInstallPeers: true
@@ -185,6 +185,9 @@ dependencies:
use-query-params: use-query-params:
specifier: ^2.2.1 specifier: ^2.2.1
version: 2.2.1(react-dom@18.2.0)(react@18.2.0) version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
uuid:
specifier: ^9.0.0
version: 9.0.0
vite-tsconfig-paths: vite-tsconfig-paths:
specifier: ^4.2.0 specifier: ^4.2.0
version: 4.2.0(typescript@5.0.4) version: 4.2.0(typescript@5.0.4)
@@ -241,6 +244,9 @@ devDependencies:
'@types/react-syntax-highlighter': '@types/react-syntax-highlighter':
specifier: ^15.5.7 specifier: ^15.5.7
version: 15.5.7 version: 15.5.7
'@types/uuid':
specifier: ^9.0.2
version: 9.0.2
'@typescript-eslint/eslint-plugin': '@typescript-eslint/eslint-plugin':
specifier: ^5.59.6 specifier: ^5.59.6
version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4) version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
@@ -3024,6 +3030,10 @@ packages:
resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==} resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
dev: false dev: false
/@types/uuid@9.0.2:
resolution: {integrity: sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==}
dev: true
/@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4): /@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==} resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
@@ -7913,6 +7923,11 @@ packages:
hasBin: true hasBin: true
dev: false dev: false
/uuid@9.0.0:
resolution: {integrity: sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==}
hasBin: true
dev: false
/vary@1.1.2: /vary@1.1.2:
resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==} resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
engines: {node: '>= 0.8'} engines: {node: '>= 0.8'}

View File

@@ -22,10 +22,10 @@ model Experiment {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
TemplateVariable TemplateVariable[] templateVariables TemplateVariable[]
PromptVariant PromptVariant[] promptVariants PromptVariant[]
TestScenario TestScenario[] testScenarios TestScenario[]
Evaluation Evaluation[] evaluations Evaluation[]
} }
model PromptVariant { model PromptVariant {
@@ -126,7 +126,7 @@ model ModelOutput {
scenarioVariantCellId String @db.Uuid scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[] outputEvaluations OutputEvaluation[]
@@unique([scenarioVariantCellId]) @@unique([scenarioVariantCellId])
@@index([inputHash]) @@index([inputHash])
@@ -150,7 +150,7 @@ model Evaluation {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
OutputEvaluation OutputEvaluation[] outputEvaluations OutputEvaluation[]
} }
model OutputEvaluation { model OutputEvaluation {
@@ -179,8 +179,8 @@ model Organization {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
OrganizationUser OrganizationUser[] organizationUsers OrganizationUser[]
Experiment Experiment[] experiments Experiment[]
} }
enum OrganizationUserRole { enum OrganizationUserRole {
@@ -241,8 +241,8 @@ model User {
image String? image String?
accounts Account[] accounts Account[]
sessions Session[] sessions Session[]
OrganizationUser OrganizationUser[] organizationUsers OrganizationUser[]
Organization Organization[] organizations Organization[]
} }
model VerificationToken { model VerificationToken {

View File

@@ -0,0 +1,77 @@
import {
Button,
Icon,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Text,
} from "@chakra-ui/react";
import { useRouter } from "next/router";
import { useRef } from "react";
import { BsTrash } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const DeleteButton = () => {
const experiment = useExperiment();
const mutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await mutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [mutation, experiment.data?.id, router]);
return (
<>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="red"
fontWeight="normal"
onClick={onOpen}
>
<Icon as={BsTrash} boxSize={4} />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Delete Experiment
</Text>
</Button>
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
};

View File

@@ -6,13 +6,14 @@ import {
DrawerHeader, DrawerHeader,
DrawerOverlay, DrawerOverlay,
Heading, Heading,
Stack, VStack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import EditScenarioVars from "./EditScenarioVars"; import EditScenarioVars from "../OutputsTable/EditScenarioVars";
import EditEvaluations from "./EditEvaluations"; import EditEvaluations from "../OutputsTable/EditEvaluations";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import { DeleteButton } from "./DeleteButton";
export default function SettingsDrawer() { export default function ExperimentSettingsDrawer() {
const isOpen = useAppStore((state) => state.drawerOpen); const isOpen = useAppStore((state) => state.drawerOpen);
const closeDrawer = useAppStore((state) => state.closeDrawer); const closeDrawer = useAppStore((state) => state.closeDrawer);
@@ -22,13 +23,16 @@ export default function SettingsDrawer() {
<DrawerContent> <DrawerContent>
<DrawerCloseButton /> <DrawerCloseButton />
<DrawerHeader> <DrawerHeader>
<Heading size="md">Settings</Heading> <Heading size="md">Experiment Settings</Heading>
</DrawerHeader> </DrawerHeader>
<DrawerBody> <DrawerBody h="full" pb={4}>
<Stack spacing={6}> <VStack h="full" justifyContent="space-between">
<VStack spacing={6}>
<EditScenarioVars /> <EditScenarioVars />
<EditEvaluations /> <EditEvaluations />
</Stack> </VStack>
<DeleteButton />
</VStack>
</DrawerBody> </DrawerBody>
</DrawerContent> </DrawerContent>
</Drawer> </Drawer>

View File

@@ -22,7 +22,7 @@ export const OutputStats = ({
return ( return (
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>
{modelOutput.outputEvaluation.map((evaluation) => { {modelOutput.outputEvaluations.map((evaluation) => {
const passed = evaluation.result > 0.5; const passed = evaluation.result > 0.5;
return ( return (
<Tooltip <Tooltip

View File

@@ -63,7 +63,7 @@ export const ScenariosHeader = () => {
</MenuButton> </MenuButton>
<MenuList fontSize="md" zIndex="dropdown" mt={-3}> <MenuList fontSize="md" zIndex="dropdown" mt={-3}>
<MenuItem <MenuItem
icon={<Icon as={BsPlus} boxSize={6} mx={-1} />} icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
onClick={() => onAddScenario(false)} onClick={() => onAddScenario(false)}
> >
Add Scenario Add Scenario

View File

@@ -1,4 +1,13 @@
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react"; import {
HStack,
Icon,
VStack,
Text,
Divider,
Spinner,
AspectRatio,
SkeletonText,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs"; import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link"; import Link from "next/link";
@@ -93,3 +102,13 @@ export const NewExperimentCard = () => {
</AspectRatio> </AspectRatio>
); );
}; };
export const ExperimentCardSkeleton = () => (
<AspectRatio ratio={1.2} w="full">
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
<SkeletonText noOfLines={1} w="80%" />
<SkeletonText noOfLines={2} w="60%" />
<SkeletonText noOfLines={1} w="80%" />
</VStack>
</AspectRatio>
);

View File

@@ -0,0 +1,57 @@
import {
Button,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
} from "@chakra-ui/react";
import { useRouter } from "next/router";
import { useRef } from "react";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
const experiment = useExperiment();
const deleteMutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await deleteMutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [deleteMutation, experiment.data?.id, router]);
return (
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};

View File

@@ -0,0 +1,46 @@
import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
import { useExperiment } from "~/utils/hooks";
import { BsGearFill } from "react-icons/bs";
import { TbGitFork } from "react-icons/tb";
import { useAppStore } from "~/state/store";
export const HeaderButtons = () => {
const experiment = useExperiment();
const canModify = experiment.data?.access.canModify ?? false;
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
const openDrawer = useAppStore((s) => s.openDrawer);
if (experiment.isLoading) return null;
return (
<HStack spacing={0}>
<Button
onClick={onForkButtonPressed}
mr={4}
colorScheme={canModify ? undefined : "orange"}
bgColor={canModify ? undefined : "orange.400"}
minW={0}
variant={canModify ? "ghost" : "solid"}
>
{isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
<Text ml={2}>Fork</Text>
</Button>
{canModify && (
<Button
mt={{ base: 2, md: 0 }}
variant={{ base: "solid", md: "ghost" }}
onClick={openDrawer}
>
<HStack>
<Icon as={BsGearFill} />
<Text>Settings</Text>
</HStack>
</Button>
)}
</HStack>
);
};

View File

@@ -0,0 +1,30 @@
import { useCallback } from "react";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
export const useOnForkButtonPressed = () => {
const router = useRouter();
const user = useSession().data;
const experiment = useExperiment();
const forkMutation = api.experiments.fork.useMutation();
const [onFork, isForking] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
}, [forkMutation, experiment.data?.id, router]);
const onForkButtonPressed = useCallback(() => {
if (user === null) {
signIn("github").catch(console.error);
} else {
onFork();
}
}, [onFork, user]);
return { onForkButtonPressed, isForking };
};

View File

@@ -2,100 +2,31 @@ import {
Box, Box,
Breadcrumb, Breadcrumb,
BreadcrumbItem, BreadcrumbItem,
Button,
Center, Center,
Flex, Flex,
Icon, Icon,
Input, Input,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Text, Text,
HStack,
VStack, VStack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import Link from "next/link"; import Link from "next/link";
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import { useState, useEffect, useRef } from "react"; import { useState, useEffect } from "react";
import { BsGearFill, BsTrash } from "react-icons/bs";
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import OutputsTable from "~/components/OutputsTable"; import OutputsTable from "~/components/OutputsTable";
import SettingsDrawer from "~/components/OutputsTable/SettingsDrawer"; import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import { useSyncVariantEditor } from "~/state/sync"; import { useSyncVariantEditor } from "~/state/sync";
import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
const DeleteButton = () => {
const experiment = useExperiment();
const mutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await mutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [mutation, experiment.data?.id, router]);
return (
<>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={onOpen}
>
<Icon as={BsTrash} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Delete Experiment
</Text>
</Button>
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
};
export default function Experiment() { export default function Experiment() {
const router = useRouter(); const router = useRouter();
const experiment = useExperiment(); const experiment = useExperiment();
const utils = api.useContext(); const utils = api.useContext();
const openDrawer = useAppStore((s) => s.openDrawer);
useSyncVariantEditor(); useSyncVariantEditor();
useEffect(() => { useEffect(() => {
@@ -138,7 +69,7 @@ export default function Experiment() {
py={2} py={2}
w="full" w="full"
direction={{ base: "column", sm: "row" }} direction={{ base: "column", sm: "row" }}
alignItems="flex-start" alignItems={{ base: "flex-start", sm: "center" }}
> >
<Breadcrumb flex={1}> <Breadcrumb flex={1}>
<BreadcrumbItem> <BreadcrumbItem>
@@ -171,25 +102,9 @@ export default function Experiment() {
)} )}
</BreadcrumbItem> </BreadcrumbItem>
</Breadcrumb> </Breadcrumb>
{canModify && ( <HeaderButtons />
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
)}
</Flex> </Flex>
<SettingsDrawer /> <ExperimentSettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}> <Box w="100%" overflowX="auto" flex={1}>
<OutputsTable experimentId={router.query.id as string | undefined} /> <OutputsTable experimentId={router.query.id as string | undefined} />
</Box> </Box>

View File

@@ -13,7 +13,11 @@ import {
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard"; import {
ExperimentCard,
ExperimentCardSkeleton,
NewExperimentCard,
} from "~/components/experiments/ExperimentCard";
import { signIn, useSession } from "next-auth/react"; import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() { export default function ExperimentsPage() {
@@ -47,7 +51,7 @@ export default function ExperimentsPage() {
return ( return (
<AppShell title="Experiments"> <AppShell title="Experiments">
<VStack alignItems={"flex-start"} px={4} py={2}> <VStack alignItems={"flex-start"} px={4} py={2}>
<HStack minH={8} align="center"> <HStack minH={8} align="center" pt={2}>
<Breadcrumb flex={1}> <Breadcrumb flex={1}>
<BreadcrumbItem> <BreadcrumbItem>
<Flex alignItems="center"> <Flex alignItems="center">
@@ -58,7 +62,15 @@ export default function ExperimentsPage() {
</HStack> </HStack>
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4"> <SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
<NewExperimentCard /> <NewExperimentCard />
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)} {experiments.data && !experiments.isLoading ? (
experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
) : (
<>
<ExperimentCardSkeleton />
<ExperimentCardSkeleton />
<ExperimentCardSkeleton />
</>
)}
</SimpleGrid> </SimpleGrid>
</VStack> </VStack>
</AppShell> </AppShell>

View File

@@ -1,5 +1,7 @@
import { z } from "zod"; import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { type Prisma } from "@prisma/client";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import dedent from "dedent"; import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
@@ -20,7 +22,7 @@ export const experimentsRouter = createTRPCRouter({
const experiments = await prisma.experiment.findMany({ const experiments = await prisma.experiment.findMany({
where: { where: {
organization: { organization: {
OrganizationUser: { organizationUsers: {
some: { userId: ctx.session.user.id }, some: { userId: ctx.session.user.id },
}, },
}, },
@@ -77,6 +79,189 @@ export const experimentsRouter = createTRPCRouter({
}; };
}), }),
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
const [
existingExp,
existingVariants,
existingScenarios,
existingCells,
evaluations,
templateVariables,
] = await prisma.$transaction([
prisma.experiment.findUniqueOrThrow({
where: {
id: input.id,
},
}),
prisma.promptVariant.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.testScenario.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.scenarioVariantCell.findMany({
where: {
testScenario: {
visible: true,
},
promptVariant: {
experimentId: input.id,
visible: true,
},
},
include: {
modelOutput: {
include: {
outputEvaluations: true,
},
},
},
}),
prisma.evaluation.findMany({
where: {
experimentId: input.id,
},
}),
prisma.templateVariable.findMany({
where: {
experimentId: input.id,
},
}),
]);
const newExperimentId = uuidv4();
const existingToNewVariantIds = new Map<string, string>();
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
for (const variant of existingVariants) {
const newVariantId = uuidv4();
existingToNewVariantIds.set(variant.id, newVariantId);
variantsToCreate.push({
...variant,
id: newVariantId,
experimentId: newExperimentId,
});
}
const existingToNewScenarioIds = new Map<string, string>();
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
for (const scenario of existingScenarios) {
const newScenarioId = uuidv4();
existingToNewScenarioIds.set(scenario.id, newScenarioId);
scenariosToCreate.push({
...scenario,
id: newScenarioId,
experimentId: newExperimentId,
variableValues: scenario.variableValues as Prisma.InputJsonValue,
});
}
const existingToNewEvaluationIds = new Map<string, string>();
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
for (const evaluation of evaluations) {
const newEvaluationId = uuidv4();
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
evaluationsToCreate.push({
...evaluation,
id: newEvaluationId,
experimentId: newExperimentId,
});
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
const { modelOutput, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
if (modelOutput) {
const newModelOutputId = uuidv4();
const { outputEvaluations, ...modelOutputData } = modelOutput;
modelOutputsToCreate.push({
...modelOutputData,
id: newModelOutputId,
scenarioVariantCellId: newCellId,
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
...evaluation,
id: uuidv4(),
modelOutputId: newModelOutputId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
});
}
}
}
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
for (const templateVariable of templateVariables) {
templateVariablesToCreate.push({
...templateVariable,
id: uuidv4(),
experimentId: newExperimentId,
});
}
const maxSortIndex =
(
await prisma.experiment.aggregate({
_max: {
sortIndex: true,
},
})
)._max?.sortIndex ?? 0;
await prisma.$transaction([
prisma.experiment.create({
data: {
id: newExperimentId,
sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`,
organizationId: (await userOrg(ctx.session.user.id)).id,
},
}),
prisma.promptVariant.createMany({
data: variantsToCreate,
}),
prisma.testScenario.createMany({
data: scenariosToCreate,
}),
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
prisma.modelOutput.createMany({
data: modelOutputsToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,
}),
prisma.outputEvaluation.createMany({
data: outputEvaluationsToCreate,
}),
prisma.templateVariable.createMany({
data: templateVariablesToCreate,
}),
]);
return newExperimentId;
}),
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
// Anyone can create an experiment // Anyone can create an experiment
requireNothing(ctx); requireNothing(ctx);

View File

@@ -29,7 +29,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
include: { include: {
modelOutput: { modelOutput: {
include: { include: {
outputEvaluation: { outputEvaluations: {
include: { include: {
evaluation: { evaluation: {
select: { label: true }, select: { label: true },

View File

@@ -56,7 +56,7 @@ export const runAllEvals = async (experimentId: string) => {
testScenario: true, testScenario: true,
}, },
}, },
outputEvaluation: true, outputEvaluations: true,
}, },
}); });
const evals = await prisma.evaluation.findMany({ const evals = await prisma.evaluation.findMany({
@@ -66,7 +66,7 @@ export const runAllEvals = async (experimentId: string) => {
await Promise.all( await Promise.all(
outputs.map(async (output) => { outputs.map(async (output) => {
const unrunEvals = evals.filter( const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id), (evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
); );
await Promise.all( await Promise.all(

View File

@@ -8,7 +8,7 @@ export default async function userOrg(userId: string) {
update: {}, update: {},
create: { create: {
personalOrgUserId: userId, personalOrgUserId: userId,
OrganizationUser: { organizationUsers: {
create: { create: {
userId: userId, userId: userId,
role: "ADMIN", role: "ADMIN",

View File

@@ -22,7 +22,7 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
where: { where: {
id: experimentId, id: experimentId,
organization: { organization: {
OrganizationUser: { organizationUsers: {
some: { some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
userId, userId,