Allow experiment forking (#89)

* Move DeleteButton into a separate file

* Rename plural relations

* Add ability to fork

* Fork automatically after auth upon return

* Add experiment card skeleton

* Create HeaderButtons component

* return no header buttons while experiment loading

* Fix prettier

* Remove unused variable

* Remove newline

* Default json values to undefined

* Change header styles

* Fix prettier

* Give AddScenario icon less width

* Move useEffect

* Skip invalidating experiments list after forking

* Require user to be able to view experiment to fork it

* Move experiment creation into same transaction

* Only return the forked experiment id

* Put delete button in experiment settings drawer

* Move useEffect hook
This commit is contained in:
arcticfly
2023-07-24 18:10:59 -07:00
committed by GitHub
parent 09140f8b5f
commit d6b97b29f7
18 changed files with 493 additions and 131 deletions

View File

@@ -81,6 +81,7 @@
"tsx": "^3.12.7",
"type-fest": "^4.0.0",
"use-query-params": "^2.2.1",
"uuid": "^9.0.0",
"vite-tsconfig-paths": "^4.2.0",
"zod": "^3.21.4",
"zustand": "^4.3.9"
@@ -101,6 +102,7 @@
"@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.59.6",
"@typescript-eslint/parser": "^5.59.6",
"eslint": "^8.40.0",

17
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.1'
lockfileVersion: '6.0'
settings:
autoInstallPeers: true
@@ -185,6 +185,9 @@ dependencies:
use-query-params:
specifier: ^2.2.1
version: 2.2.1(react-dom@18.2.0)(react@18.2.0)
uuid:
specifier: ^9.0.0
version: 9.0.0
vite-tsconfig-paths:
specifier: ^4.2.0
version: 4.2.0(typescript@5.0.4)
@@ -241,6 +244,9 @@ devDependencies:
'@types/react-syntax-highlighter':
specifier: ^15.5.7
version: 15.5.7
'@types/uuid':
specifier: ^9.0.2
version: 9.0.2
'@typescript-eslint/eslint-plugin':
specifier: ^5.59.6
version: 5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4)
@@ -3024,6 +3030,10 @@ packages:
resolution: {integrity: sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g==}
dev: false
/@types/uuid@9.0.2:
resolution: {integrity: sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==}
dev: true
/@typescript-eslint/eslint-plugin@5.59.6(@typescript-eslint/parser@5.59.6)(eslint@8.40.0)(typescript@5.0.4):
resolution: {integrity: sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==}
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
@@ -7913,6 +7923,11 @@ packages:
hasBin: true
dev: false
/uuid@9.0.0:
resolution: {integrity: sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==}
hasBin: true
dev: false
/vary@1.1.2:
resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==}
engines: {node: '>= 0.8'}

View File

@@ -22,10 +22,10 @@ model Experiment {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[]
TestScenario TestScenario[]
Evaluation Evaluation[]
templateVariables TemplateVariable[]
promptVariants PromptVariant[]
testScenarios TestScenario[]
evaluations Evaluation[]
}
model PromptVariant {
@@ -126,7 +126,7 @@ model ModelOutput {
scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[]
outputEvaluations OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash])
@@ -150,7 +150,7 @@ model Evaluation {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
OutputEvaluation OutputEvaluation[]
outputEvaluations OutputEvaluation[]
}
model OutputEvaluation {
@@ -179,8 +179,8 @@ model Organization {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
OrganizationUser OrganizationUser[]
Experiment Experiment[]
organizationUsers OrganizationUser[]
experiments Experiment[]
}
enum OrganizationUserRole {
@@ -234,15 +234,15 @@ model Session {
}
model User {
id String @id @default(uuid()) @db.Uuid
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
OrganizationUser OrganizationUser[]
Organization Organization[]
id String @id @default(uuid()) @db.Uuid
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
organizationUsers OrganizationUser[]
organizations Organization[]
}
model VerificationToken {

View File

@@ -0,0 +1,77 @@
import {
Button,
Icon,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Text,
} from "@chakra-ui/react";
import { useRouter } from "next/router";
import { useRef } from "react";
import { BsTrash } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const DeleteButton = () => {
const experiment = useExperiment();
const mutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await mutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [mutation, experiment.data?.id, router]);
return (
<>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="red"
fontWeight="normal"
onClick={onOpen}
>
<Icon as={BsTrash} boxSize={4} />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Delete Experiment
</Text>
</Button>
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
};

View File

@@ -6,13 +6,14 @@ import {
DrawerHeader,
DrawerOverlay,
Heading,
Stack,
VStack,
} from "@chakra-ui/react";
import EditScenarioVars from "./EditScenarioVars";
import EditEvaluations from "./EditEvaluations";
import EditScenarioVars from "../OutputsTable/EditScenarioVars";
import EditEvaluations from "../OutputsTable/EditEvaluations";
import { useAppStore } from "~/state/store";
import { DeleteButton } from "./DeleteButton";
export default function SettingsDrawer() {
export default function ExperimentSettingsDrawer() {
const isOpen = useAppStore((state) => state.drawerOpen);
const closeDrawer = useAppStore((state) => state.closeDrawer);
@@ -22,13 +23,16 @@ export default function SettingsDrawer() {
<DrawerContent>
<DrawerCloseButton />
<DrawerHeader>
<Heading size="md">Settings</Heading>
<Heading size="md">Experiment Settings</Heading>
</DrawerHeader>
<DrawerBody>
<Stack spacing={6}>
<EditScenarioVars />
<EditEvaluations />
</Stack>
<DrawerBody h="full" pb={4}>
<VStack h="full" justifyContent="space-between">
<VStack spacing={6}>
<EditScenarioVars />
<EditEvaluations />
</VStack>
<DeleteButton />
</VStack>
</DrawerBody>
</DrawerContent>
</Drawer>

View File

@@ -22,7 +22,7 @@ export const OutputStats = ({
return (
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}>
{modelOutput.outputEvaluation.map((evaluation) => {
{modelOutput.outputEvaluations.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
<Tooltip

View File

@@ -63,7 +63,7 @@ export const ScenariosHeader = () => {
</MenuButton>
<MenuList fontSize="md" zIndex="dropdown" mt={-3}>
<MenuItem
icon={<Icon as={BsPlus} boxSize={6} mx={-1} />}
icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
onClick={() => onAddScenario(false)}
>
Add Scenario

View File

@@ -1,4 +1,13 @@
import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
import {
HStack,
Icon,
VStack,
Text,
Divider,
Spinner,
AspectRatio,
SkeletonText,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link";
@@ -93,3 +102,13 @@ export const NewExperimentCard = () => {
</AspectRatio>
);
};
export const ExperimentCardSkeleton = () => (
<AspectRatio ratio={1.2} w="full">
<VStack align="center" borderColor="gray.200" borderWidth={1} p={4} bg="gray.50">
<SkeletonText noOfLines={1} w="80%" />
<SkeletonText noOfLines={2} w="60%" />
<SkeletonText noOfLines={1} w="80%" />
</VStack>
</AspectRatio>
);

View File

@@ -0,0 +1,57 @@
import {
Button,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
} from "@chakra-ui/react";
import { useRouter } from "next/router";
import { useRef } from "react";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const DeleteDialog = ({ onClose }: { onClose: () => void }) => {
const experiment = useExperiment();
const deleteMutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await deleteMutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [deleteMutation, experiment.data?.id, router]);
return (
<AlertDialog isOpen leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};

View File

@@ -0,0 +1,46 @@
import { Button, HStack, Icon, Spinner, Text } from "@chakra-ui/react";
import { useOnForkButtonPressed } from "./useOnForkButtonPressed";
import { useExperiment } from "~/utils/hooks";
import { BsGearFill } from "react-icons/bs";
import { TbGitFork } from "react-icons/tb";
import { useAppStore } from "~/state/store";
export const HeaderButtons = () => {
const experiment = useExperiment();
const canModify = experiment.data?.access.canModify ?? false;
const { onForkButtonPressed, isForking } = useOnForkButtonPressed();
const openDrawer = useAppStore((s) => s.openDrawer);
if (experiment.isLoading) return null;
return (
<HStack spacing={0}>
<Button
onClick={onForkButtonPressed}
mr={4}
colorScheme={canModify ? undefined : "orange"}
bgColor={canModify ? undefined : "orange.400"}
minW={0}
variant={canModify ? "ghost" : "solid"}
>
{isForking ? <Spinner boxSize={5} /> : <Icon as={TbGitFork} boxSize={5} />}
<Text ml={2}>Fork</Text>
</Button>
{canModify && (
<Button
mt={{ base: 2, md: 0 }}
variant={{ base: "solid", md: "ghost" }}
onClick={openDrawer}
>
<HStack>
<Icon as={BsGearFill} />
<Text>Settings</Text>
</HStack>
</Button>
)}
</HStack>
);
};

View File

@@ -0,0 +1,30 @@
import { useCallback } from "react";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
export const useOnForkButtonPressed = () => {
const router = useRouter();
const user = useSession().data;
const experiment = useExperiment();
const forkMutation = api.experiments.fork.useMutation();
const [onFork, isForking] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
const forkedExperimentId = await forkMutation.mutateAsync({ id: experiment.data.id });
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
}, [forkMutation, experiment.data?.id, router]);
const onForkButtonPressed = useCallback(() => {
if (user === null) {
signIn("github").catch(console.error);
} else {
onFork();
}
}, [onFork, user]);
return { onForkButtonPressed, isForking };
};

View File

@@ -2,100 +2,31 @@ import {
Box,
Breadcrumb,
BreadcrumbItem,
Button,
Center,
Flex,
Icon,
Input,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Text,
HStack,
VStack,
} from "@chakra-ui/react";
import Link from "next/link";
import { useRouter } from "next/router";
import { useState, useEffect, useRef } from "react";
import { BsGearFill, BsTrash } from "react-icons/bs";
import { useState, useEffect } from "react";
import { RiFlaskLine } from "react-icons/ri";
import OutputsTable from "~/components/OutputsTable";
import SettingsDrawer from "~/components/OutputsTable/SettingsDrawer";
import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
import { useSyncVariantEditor } from "~/state/sync";
const DeleteButton = () => {
const experiment = useExperiment();
const mutation = api.experiments.delete.useMutation();
const utils = api.useContext();
const router = useRouter();
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return;
await mutation.mutateAsync({ id: experiment.data.id });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments" });
onClose();
}, [mutation, experiment.data?.id, router]);
return (
<>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={onOpen}
>
<Icon as={BsTrash} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Delete Experiment
</Text>
</Button>
<AlertDialog isOpen={isOpen} leastDestructiveRef={cancelRef} onClose={onClose}>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete Experiment
</AlertDialogHeader>
<AlertDialogBody>
If you delete this experiment all the associated prompts and scenarios will be deleted
as well. Are you sure?
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={onDeleteConfirm} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
};
import { HeaderButtons } from "~/components/experiments/HeaderButtons/HeaderButtons";
export default function Experiment() {
const router = useRouter();
const experiment = useExperiment();
const utils = api.useContext();
const openDrawer = useAppStore((s) => s.openDrawer);
useSyncVariantEditor();
useEffect(() => {
@@ -138,7 +69,7 @@ export default function Experiment() {
py={2}
w="full"
direction={{ base: "column", sm: "row" }}
alignItems="flex-start"
alignItems={{ base: "flex-start", sm: "center" }}
>
<Breadcrumb flex={1}>
<BreadcrumbItem>
@@ -171,25 +102,9 @@ export default function Experiment() {
)}
</BreadcrumbItem>
</Breadcrumb>
{canModify && (
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
)}
<HeaderButtons />
</Flex>
<SettingsDrawer />
<ExperimentSettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}>
<OutputsTable experimentId={router.query.id as string | undefined} />
</Box>

View File

@@ -13,7 +13,11 @@ import {
import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
import {
ExperimentCard,
ExperimentCardSkeleton,
NewExperimentCard,
} from "~/components/experiments/ExperimentCard";
import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() {
@@ -47,7 +51,7 @@ export default function ExperimentsPage() {
return (
<AppShell title="Experiments">
<VStack alignItems={"flex-start"} px={4} py={2}>
<HStack minH={8} align="center">
<HStack minH={8} align="center" pt={2}>
<Breadcrumb flex={1}>
<BreadcrumbItem>
<Flex alignItems="center">
@@ -58,7 +62,15 @@ export default function ExperimentsPage() {
</HStack>
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
<NewExperimentCard />
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
{experiments.data && !experiments.isLoading ? (
experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)
) : (
<>
<ExperimentCardSkeleton />
<ExperimentCardSkeleton />
<ExperimentCardSkeleton />
</>
)}
</SimpleGrid>
</VStack>
</AppShell>

View File

@@ -1,5 +1,7 @@
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { type Prisma } from "@prisma/client";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
@@ -20,7 +22,7 @@ export const experimentsRouter = createTRPCRouter({
const experiments = await prisma.experiment.findMany({
where: {
organization: {
OrganizationUser: {
organizationUsers: {
some: { userId: ctx.session.user.id },
},
},
@@ -77,6 +79,189 @@ export const experimentsRouter = createTRPCRouter({
};
}),
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
const [
existingExp,
existingVariants,
existingScenarios,
existingCells,
evaluations,
templateVariables,
] = await prisma.$transaction([
prisma.experiment.findUniqueOrThrow({
where: {
id: input.id,
},
}),
prisma.promptVariant.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.testScenario.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.scenarioVariantCell.findMany({
where: {
testScenario: {
visible: true,
},
promptVariant: {
experimentId: input.id,
visible: true,
},
},
include: {
modelOutput: {
include: {
outputEvaluations: true,
},
},
},
}),
prisma.evaluation.findMany({
where: {
experimentId: input.id,
},
}),
prisma.templateVariable.findMany({
where: {
experimentId: input.id,
},
}),
]);
const newExperimentId = uuidv4();
const existingToNewVariantIds = new Map<string, string>();
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
for (const variant of existingVariants) {
const newVariantId = uuidv4();
existingToNewVariantIds.set(variant.id, newVariantId);
variantsToCreate.push({
...variant,
id: newVariantId,
experimentId: newExperimentId,
});
}
const existingToNewScenarioIds = new Map<string, string>();
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
for (const scenario of existingScenarios) {
const newScenarioId = uuidv4();
existingToNewScenarioIds.set(scenario.id, newScenarioId);
scenariosToCreate.push({
...scenario,
id: newScenarioId,
experimentId: newExperimentId,
variableValues: scenario.variableValues as Prisma.InputJsonValue,
});
}
const existingToNewEvaluationIds = new Map<string, string>();
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
for (const evaluation of evaluations) {
const newEvaluationId = uuidv4();
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
evaluationsToCreate.push({
...evaluation,
id: newEvaluationId,
experimentId: newExperimentId,
});
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
const { modelOutput, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
if (modelOutput) {
const newModelOutputId = uuidv4();
const { outputEvaluations, ...modelOutputData } = modelOutput;
modelOutputsToCreate.push({
...modelOutputData,
id: newModelOutputId,
scenarioVariantCellId: newCellId,
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
...evaluation,
id: uuidv4(),
modelOutputId: newModelOutputId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
});
}
}
}
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
for (const templateVariable of templateVariables) {
templateVariablesToCreate.push({
...templateVariable,
id: uuidv4(),
experimentId: newExperimentId,
});
}
const maxSortIndex =
(
await prisma.experiment.aggregate({
_max: {
sortIndex: true,
},
})
)._max?.sortIndex ?? 0;
await prisma.$transaction([
prisma.experiment.create({
data: {
id: newExperimentId,
sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`,
organizationId: (await userOrg(ctx.session.user.id)).id,
},
}),
prisma.promptVariant.createMany({
data: variantsToCreate,
}),
prisma.testScenario.createMany({
data: scenariosToCreate,
}),
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
prisma.modelOutput.createMany({
data: modelOutputsToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,
}),
prisma.outputEvaluation.createMany({
data: outputEvaluationsToCreate,
}),
prisma.templateVariable.createMany({
data: templateVariablesToCreate,
}),
]);
return newExperimentId;
}),
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
// Anyone can create an experiment
requireNothing(ctx);

View File

@@ -29,7 +29,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
include: {
modelOutput: {
include: {
outputEvaluation: {
outputEvaluations: {
include: {
evaluation: {
select: { label: true },

View File

@@ -56,7 +56,7 @@ export const runAllEvals = async (experimentId: string) => {
testScenario: true,
},
},
outputEvaluation: true,
outputEvaluations: true,
},
});
const evals = await prisma.evaluation.findMany({
@@ -66,7 +66,7 @@ export const runAllEvals = async (experimentId: string) => {
await Promise.all(
outputs.map(async (output) => {
const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
);
await Promise.all(

View File

@@ -8,7 +8,7 @@ export default async function userOrg(userId: string) {
update: {},
create: {
personalOrgUserId: userId,
OrganizationUser: {
organizationUsers: {
create: {
userId: userId,
role: "ADMIN",

View File

@@ -22,7 +22,7 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
where: {
id: experimentId,
organization: {
OrganizationUser: {
organizationUsers: {
some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
userId,