add evaluations

This commit is contained in:
Kyle Corbitt
2023-07-06 13:39:13 -07:00
parent 1ae5612d55
commit f728027ef6
18 changed files with 614 additions and 68 deletions

View File

@@ -29,6 +29,7 @@
"@trpc/next": "^10.26.0", "@trpc/next": "^10.26.0",
"@trpc/react-query": "^10.26.0", "@trpc/react-query": "^10.26.0",
"@trpc/server": "^10.26.0", "@trpc/server": "^10.26.0",
"chroma-js": "^2.4.2",
"concurrently": "^8.2.0", "concurrently": "^8.2.0",
"cors": "^2.8.5", "cors": "^2.8.5",
"dayjs": "^1.11.8", "dayjs": "^1.11.8",
@@ -57,6 +58,7 @@
}, },
"devDependencies": { "devDependencies": {
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5", "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
"@types/chroma-js": "^2.4.0",
"@types/cors": "^2.8.13", "@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0", "@types/eslint": "^8.37.0",
"@types/express": "^4.17.17", "@types/express": "^4.17.17",

16
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.0' lockfileVersion: '6.1'
settings: settings:
autoInstallPeers: true autoInstallPeers: true
@@ -50,6 +50,9 @@ dependencies:
'@trpc/server': '@trpc/server':
specifier: ^10.26.0 specifier: ^10.26.0
version: 10.26.0 version: 10.26.0
chroma-js:
specifier: ^2.4.2
version: 2.4.2
concurrently: concurrently:
specifier: ^8.2.0 specifier: ^8.2.0
version: 8.2.0 version: 8.2.0
@@ -130,6 +133,9 @@ devDependencies:
'@openapi-contrib/openapi-schema-to-json-schema': '@openapi-contrib/openapi-schema-to-json-schema':
specifier: ^4.0.5 specifier: ^4.0.5
version: 4.0.5 version: 4.0.5
'@types/chroma-js':
specifier: ^2.4.0
version: 2.4.0
'@types/cors': '@types/cors':
specifier: ^2.8.13 specifier: ^2.8.13
version: 2.8.13 version: 2.8.13
@@ -2072,6 +2078,10 @@ packages:
'@types/node': 18.16.0 '@types/node': 18.16.0
dev: true dev: true
/@types/chroma-js@2.4.0:
resolution: {integrity: sha512-JklMxityrwjBTjGY2anH8JaTx3yjRU3/sEHSblLH1ba5lqcSh1LnImXJZO5peJfXyqKYWjHTGy4s5Wz++hARrw==}
dev: true
/@types/connect@3.4.35: /@types/connect@3.4.35:
resolution: {integrity: sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ==} resolution: {integrity: sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ==}
dependencies: dependencies:
@@ -2697,6 +2707,10 @@ packages:
fsevents: 2.3.2 fsevents: 2.3.2
dev: false dev: false
/chroma-js@2.4.2:
resolution: {integrity: sha512-U9eDw6+wt7V8z5NncY2jJfZa+hUH8XEj8FQHgFJTrUFnJfXYf4Ml4adI2vXZOjqRDpFWtYVWypDfZwnJ+HIR4A==}
dev: false
/client-only@0.0.1: /client-only@0.0.1:
resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==}
dev: false dev: false

View File

@@ -0,0 +1,40 @@
-- CreateEnum
CREATE TYPE "EvaluationMatchType" AS ENUM ('CONTAINS', 'DOES_NOT_CONTAIN');
-- CreateTable
CREATE TABLE "Evaluation" (
"id" UUID NOT NULL,
"name" TEXT NOT NULL,
"matchString" TEXT NOT NULL,
"matchType" "EvaluationMatchType" NOT NULL,
"experimentId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "Evaluation_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "EvaluationResult" (
"id" UUID NOT NULL,
"passCount" INTEGER NOT NULL,
"failCount" INTEGER NOT NULL,
"evaluationId" UUID NOT NULL,
"promptVariantId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "EvaluationResult_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "EvaluationResult_evaluationId_promptVariantId_key" ON "EvaluationResult"("evaluationId", "promptVariantId");
-- AddForeignKey
ALTER TABLE "Evaluation" ADD CONSTRAINT "Evaluation_experimentId_fkey" FOREIGN KEY ("experimentId") REFERENCES "Experiment"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -22,6 +22,7 @@ model Experiment {
TemplateVariable TemplateVariable[] TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[] PromptVariant PromptVariant[]
TestScenario TestScenario[] TestScenario TestScenario[]
Evaluation Evaluation[]
} }
model PromptVariant { model PromptVariant {
@@ -37,9 +38,10 @@ model PromptVariant {
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
ModelOutput ModelOutput[] ModelOutput ModelOutput[]
EvaluationResult EvaluationResult[]
@@index([uiId]) @@index([uiId])
} }
@@ -80,7 +82,7 @@ model ModelOutput {
output Json output Json
statusCode Int statusCode Int
errorMessage String? errorMessage String?
timeToComplete Int @default(0) timeToComplete Int @default(0)
promptTokens Int? // Added promptTokens field promptTokens Int? // Added promptTokens field
completionTokens Int? // Added completionTokens field completionTokens Int? // Added completionTokens field
@@ -98,6 +100,44 @@ model ModelOutput {
@@index([inputHash]) @@index([inputHash])
} }
enum EvaluationMatchType {
CONTAINS
DOES_NOT_CONTAIN
}
model Evaluation {
id String @id @default(uuid()) @db.Uuid
name String
matchString String
matchType EvaluationMatchType
experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
EvaluationResult EvaluationResult[]
}
model EvaluationResult {
id String @id @default(uuid()) @db.Uuid
passCount Int
failCount Int
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([evaluationId, promptVariantId])
}
// Necessary for Next auth // Necessary for Next auth
model Account { model Account {
id String @id @default(cuid()) id String @id @default(cuid())

View File

@@ -1,43 +1,216 @@
import { Text, Heading, Stack } from "@chakra-ui/react"; import {
import { useState } from "react"; Text,
Button,
HStack,
Heading,
Icon,
Input,
Stack,
VStack,
FormControl,
FormLabel,
Select,
FormHelperText,
} from "@chakra-ui/react";
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useStore } from "~/utils/store";
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
export function EvaluationEditor(props: {
evaluation: Evaluation | null;
defaultName?: string;
onSave: (id: string | undefined, vals: EvalValues) => void;
onCancel: () => void;
}) {
const [values, setValues] = useState<EvalValues>({
name: props.evaluation?.name ?? props.defaultName ?? "",
matchString: props.evaluation?.matchString ?? "",
matchType: props.evaluation?.matchType ?? "CONTAINS",
});
return (
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
<HStack w="100%">
<FormControl flex={1}>
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
<Input
size="sm"
value={values.name}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
/>
</FormControl>
<FormControl flex={1}>
<FormLabel fontSize="sm">Match Type</FormLabel>
<Select
size="sm"
value={values.matchType}
onChange={(e) =>
setValues((values) => ({
...values,
matchType: e.target.value as EvaluationMatchType,
}))
}
>
{Object.values(EvaluationMatchType).map((type) => (
<option key={type} value={type}>
{type}
</option>
))}
</Select>
</FormControl>
</HStack>
<FormControl>
<FormLabel fontSize="sm">Match String</FormLabel>
<FormHelperText>
This string will be interpreted as a regex and checked against each model output.
</FormHelperText>
<Input
size="sm"
value={values.matchString}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
/>
</FormControl>
<HStack alignSelf="flex-end">
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
Cancel
</Button>
<Button
size="sm"
onClick={() => props.onSave(props.evaluation?.id, values)}
colorScheme="blue"
>
Save
</Button>
</HStack>
</VStack>
);
}
export default function EditEvaluations() { export default function EditEvaluations() {
const experiment = useExperiment(); const experiment = useExperiment();
const vars = const evaluations =
api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const [newVariable, setNewVariable] = useState<string>(""); const [editingId, setEditingId] = useState<string | null>(null);
const newVarIsValid = newVariable.length > 0 && !vars.map((v) => v.label).includes(newVariable);
const utils = api.useContext(); const utils = api.useContext();
const addVarMutation = api.templateVars.create.useMutation(); const createMutation = api.evaluations.create.useMutation();
const [onAddVar] = useHandledAsyncCallback(async () => { const updateMutation = api.evaluations.update.useMutation();
if (!experiment.data?.id) return;
if (!newVarIsValid) return;
await addVarMutation.mutateAsync({
experimentId: experiment.data.id,
label: newVariable,
});
await utils.templateVars.list.invalidate();
setNewVariable("");
}, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]);
const deleteMutation = api.templateVars.delete.useMutation(); const deleteMutation = api.evaluations.delete.useMutation();
const [onDeleteVar] = useHandledAsyncCallback(async (id: string) => { const [onDelete] = useHandledAsyncCallback(async (id: string) => {
await deleteMutation.mutateAsync({ id }); await deleteMutation.mutateAsync({ id });
await utils.templateVars.list.invalidate(); await utils.evaluations.list.invalidate();
await utils.evaluations.results.invalidate();
}, []); }, []);
const closeDrawer = useStore((state) => state.closeDrawer); const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
setEditingId(null);
if (!experiment.data?.id) return;
if (id) {
await updateMutation.mutateAsync({
id,
updates: vals,
});
} else {
await createMutation.mutateAsync({
experimentId: experiment.data.id,
...vals,
});
}
await utils.evaluations.list.invalidate();
await utils.evaluations.results.invalidate();
}, []);
const onCancel = useCallback(() => {
setEditingId(null);
}, []);
return ( return (
<Stack> <Stack>
<Heading size="sm">Edit Evaluations</Heading> <Heading size="sm">Evaluations</Heading>
<Stack spacing={2} pt={2}> <Stack spacing={4}>
<Text fontSize="sm"></Text> <Text fontSize="sm">
Evaluations allow you to compare prompt performance in an automated way.
</Text>
<Stack spacing={2}>
{evaluations.map((evaluation) =>
editingId == evaluation.id ? (
<EvaluationEditor
evaluation={evaluation}
onSave={onSave}
onCancel={onCancel}
key={evaluation.id}
/>
) : (
<HStack
fontSize="sm"
borderTopWidth={1}
borderColor="gray.200"
py={4}
align="center"
key={evaluation.id}
>
<Text fontWeight="bold">{evaluation.name}</Text>
<Text flex={1}>
{evaluation.matchType}: &quot;{evaluation.matchString}&quot;
</Text>
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={() => setEditingId(evaluation.id)}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={BsPencil} boxSize={4} />
</Button>
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={() => onDelete(evaluation.id)}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={BsX} boxSize={6} />
</Button>
</HStack>
)
)}
{editingId == null && (
<Button
onClick={() => setEditingId("new")}
alignSelf="flex-start"
size="sm"
mt={4}
colorScheme="blue"
>
Add Evaluation
</Button>
)}
{editingId == "new" && (
<EvaluationEditor
evaluation={null}
defaultName={`Eval${evaluations.length + 1}`}
onSave={onSave}
onCancel={onCancel}
/>
)}
</Stack>
</Stack> </Stack>
</Stack> </Stack>
); );

View File

@@ -33,8 +33,8 @@ export default function EditScenarioVars() {
return ( return (
<Stack> <Stack>
<Heading size="sm">Edit Scenario Variables</Heading> <Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2} pt={2}> <Stack spacing={2}>
<Text fontSize="sm"> <Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Reference Scenario variables can be used in your prompt variants as well as evaluations. Reference
them using <Code>{"{{curly_braces}}"}</Code>. them using <Code>{"{{curly_braces}}"}</Code>.

View File

@@ -1,4 +1,4 @@
import { Button } from "@chakra-ui/react"; import { Button, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs"; import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
@@ -9,7 +9,7 @@ export default function NewVariantButton() {
const mutation = api.promptVariants.create.useMutation(); const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext(); const utils = api.useContext();
const [onClick] = useHandledAsyncCallback(async () => { const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return; if (!experiment.data) return;
await mutation.mutateAsync({ await mutation.mutateAsync({
experimentId: experiment.data.id, experimentId: experiment.data.id,
@@ -30,7 +30,7 @@ export default function NewVariantButton() {
height="unset" height="unset"
minH={headerMinHeight} minH={headerMinHeight}
> >
<BsPlus size={24} /> <Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
Add Variant Add Variant
</Button> </Button>
); );

View File

@@ -1,17 +1,18 @@
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types"; import { type PromptVariant, type Scenario } from "./types";
import { Spinner, Text, Box, Center, Flex, Icon } from "@chakra-ui/react"; import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks"; import { useExperiment } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { useMemo, type ReactElement } from "react"; import { useMemo, type ReactElement } from "react";
import { BsClock } from "react-icons/bs"; import { BsCheck, BsClock, BsX } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client"; import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash"; import { isObject } from "lodash";
import useSocket from "~/utils/useSocket"; import useSocket from "~/utils/useSocket";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -109,7 +110,7 @@ export default function OutputCell({
{ maxLength: 40 } { maxLength: 40 }
)} )}
</SyntaxHighlighter> </SyntaxHighlighter>
<OutputStats modelOutput={output.data} /> <OutputStats modelOutput={output.data} scenario={scenario} />
</Box> </Box>
); );
} }
@@ -120,18 +121,44 @@ export default function OutputCell({
return ( return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap"> <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay} {contentToDisplay}
{output.data && <OutputStats modelOutput={output.data} />} {output.data && <OutputStats modelOutput={output.data} scenario={scenario} />}
</Flex> </Flex>
); );
} }
const OutputStats = ({ modelOutput }: { modelOutput: ModelOutput }) => { const OutputStats = ({
modelOutput,
scenario,
}: {
modelOutput: ModelOutput;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete; const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
return ( return (
<Flex justifyContent="flex-end" alignItems="center" color="gray.500" fontSize="xs" mt={2}> <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
<Icon as={BsClock} mr={0.5} /> <HStack flex={1}>
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text> {evals.map((evaluation) => {
</Flex> const passed = evaluateOutput(modelOutput, scenario, evaluation);
return (
<HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.name}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
);
})}
</HStack>
<HStack>
<Icon as={BsClock} mr={0.5} />
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
</HStack>
</HStack>
); );
}; };

View File

@@ -6,9 +6,11 @@ import {
DrawerHeader, DrawerHeader,
DrawerOverlay, DrawerOverlay,
Heading, Heading,
Stack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { useStore } from "~/utils/store"; import { useStore } from "~/utils/store";
import EditScenarioVars from "./EditScenarioVars"; import EditScenarioVars from "./EditScenarioVars";
import EditEvaluations from "./EditEvaluations";
export default function SettingsDrawer() { export default function SettingsDrawer() {
const isOpen = useStore((state) => state.drawerOpen); const isOpen = useStore((state) => state.drawerOpen);
@@ -23,8 +25,10 @@ export default function SettingsDrawer() {
<Heading size="md">Settings</Heading> <Heading size="md">Settings</Heading>
</DrawerHeader> </DrawerHeader>
<DrawerBody> <DrawerBody>
<EditScenarioVars /> <Stack spacing={6}>
{/* <EditEvaluations /> */} <EditScenarioVars />
<EditEvaluations />
</Stack>
</DrawerBody> </DrawerBody>
</DrawerContent> </DrawerContent>
</Drawer> </Drawer>

View File

@@ -0,0 +1,38 @@
import { HStack, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types";
import { cellPadding } from "../constants";
import { api } from "~/utils/api";
import chroma from "chroma-js";
export default function VariantStats(props: { variant: PromptVariant }) {
const evalResults =
api.evaluations.results.useQuery({
variantId: props.variant.id,
}).data ?? [];
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
"gray.500",
"red.500",
]);
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
if (!(evalResults.length > 0)) return null;
return (
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
{evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
return (
<HStack key={result.id}>
<Text>{result.evaluation.name}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}%
</Text>
</HStack>
);
})}
</HStack>
);
}

View File

@@ -1,11 +1,4 @@
import { import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
Button,
Grid,
GridItem,
HStack,
Heading,
type SystemStyleObject,
} from "@chakra-ui/react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton"; import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton"; import NewVariantButton from "./NewVariantButton";
@@ -15,6 +8,7 @@ import VariantHeader from "./VariantHeader";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import { BsPencil } from "react-icons/bs"; import { BsPencil } from "react-icons/bs";
import { useStore } from "~/utils/store"; import { useStore } from "~/utils/store";
import VariantStats from "./VariantStats";
const stickyHeaderStyle: SystemStyleObject = { const stickyHeaderStyle: SystemStyleObject = {
position: "sticky", position: "sticky",
@@ -38,6 +32,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
if (!variants.data || !scenarios.data) return null; if (!variants.data || !scenarios.data) return null;
const allCols = variants.data.length + 1; const allCols = variants.data.length + 1;
const headerRows = 3;
return ( return (
<Grid <Grid
@@ -55,7 +50,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
<GridItem <GridItem
display="flex" display="flex"
alignItems="flex-end" alignItems="flex-end"
rowSpan={2} rowSpan={headerRows}
px={cellPadding.x} px={cellPadding.x}
py={cellPadding.y} py={cellPadding.y}
> >
@@ -82,7 +77,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
</GridItem> </GridItem>
))} ))}
<GridItem <GridItem
rowSpan={scenarios.data.length + 2} rowSpan={scenarios.data.length + headerRows}
padding={0} padding={0}
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid // Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
style={{ borderRightWidth: 0, borderBottomWidth: 0 }} style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
@@ -92,10 +87,15 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
</GridItem> </GridItem>
{variants.data.map((variant) => ( {variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0}> <GridItem key={variant.uiId}>
<VariantConfigEditor variant={variant} /> <VariantConfigEditor variant={variant} />
</GridItem> </GridItem>
))} ))}
{variants.data.map((variant) => (
<GridItem key={variant.uiId}>
<VariantStats variant={variant} />
</GridItem>
))}
{scenarios.data.map((scenario) => ( {scenarios.data.map((scenario) => (
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} /> <ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
))} ))}

View File

@@ -4,6 +4,7 @@ import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router"; import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router"; import { modelOutputsRouter } from "./routers/modelOutputs.router";
import { templateVarsRouter } from "./routers/templateVariables.router"; import { templateVarsRouter } from "./routers/templateVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router";
/** /**
* This is the primary router for your server. * This is the primary router for your server.
@@ -16,6 +17,7 @@ export const appRouter = createTRPCRouter({
scenarios: scenariosRouter, scenarios: scenariosRouter,
outputs: modelOutputsRouter, outputs: modelOutputsRouter,
templateVars: templateVarsRouter, templateVars: templateVarsRouter,
evaluations: evaluationsRouter,
}); });
// export type definition of API // export type definition of API

View File

@@ -0,0 +1,77 @@
import { EvaluationMatchType } from "@prisma/client";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { reevaluateEvaluation } from "~/server/utils/evaluations";
export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
return await prisma.evaluationResult.findMany({
where: {
promptVariantId: input.variantId,
},
include: { evaluation: true },
});
}),
create: publicProcedure
.input(
z.object({
experimentId: z.string(),
name: z.string(),
matchString: z.string(),
matchType: z.nativeEnum(EvaluationMatchType),
})
)
.mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
name: input.name,
matchString: input.matchString,
matchType: input.matchType,
},
});
await reevaluateEvaluation(evaluation);
}),
update: publicProcedure
.input(
z.object({
id: z.string(),
updates: z.object({
name: z.string().optional(),
matchString: z.string().optional(),
matchType: z.nativeEnum(EvaluationMatchType).optional(),
}),
})
)
.mutation(async ({ input }) => {
await prisma.evaluation.update({
where: { id: input.id },
data: {
name: input.updates.name,
matchString: input.updates.matchString,
matchType: input.updates.matchType,
},
});
await reevaluateEvaluation(
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } })
);
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
});

View File

@@ -1,15 +1,18 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate"; import { fillTemplateJson, type VariableMap } from "~/server/utils/fillTemplate";
import { type JSONSerializable } from "~/server/types"; import { type JSONSerializable } from "~/server/types";
import { getCompletion } from "~/server/utils/getCompletion";
import crypto from "crypto"; import crypto from "crypto";
import type { Prisma } from "@prisma/client"; import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
export const modelOutputsRouter = createTRPCRouter({ export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure get: publicProcedure
.input(z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })) .input(
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
)
.query(async ({ input }) => { .query(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({ const existing = await prisma.modelOutput.findUnique({
where: { where: {
@@ -36,7 +39,7 @@ export const modelOutputsRouter = createTRPCRouter({
if (!variant || !scenario) return null; if (!variant || !scenario) return null;
const filledTemplate = fillTemplate( const filledTemplate = fillTemplateJson(
variant.config as JSONSerializable, variant.config as JSONSerializable,
scenario.variableValues as VariableMap scenario.variableValues as VariableMap
); );
@@ -73,6 +76,8 @@ export const modelOutputsRouter = createTRPCRouter({
}, },
}); });
await reevaluateVariant(input.variantId);
return modelOutput; return modelOutput;
}), }),
}); });

View File

@@ -10,9 +10,7 @@ export const promptVariantsRouter = createTRPCRouter({
experimentId: input.experimentId, experimentId: input.experimentId,
visible: true, visible: true,
}, },
orderBy: { orderBy: { sortIndex: "asc" },
sortIndex: "asc",
},
}); });
}), }),

View File

@@ -0,0 +1,31 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
export const evaluateOutput = (
modelOutput: ModelOutput,
scenario: TestScenario,
evaluation: Evaluation
): boolean => {
const output = modelOutput.output as unknown as ChatCompletion;
const message = output.choices?.[0]?.message;
if (!message) return false;
const stringifiedMessage = JSON.stringify(message);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
let match;
switch (evaluation.matchType) {
case "CONTAINS":
match = stringifiedMessage.match(matchRegex) !== null;
break;
case "DOES_NOT_CONTAIN":
match = stringifiedMessage.match(matchRegex) === null;
break;
}
return match;
};

View File

@@ -0,0 +1,91 @@
import { type Evaluation } from "@prisma/client";
import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput";
export const reevaluateVariant = async (variantId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: { id: variantId },
});
if (!variant) return;
const evaluations = await prisma.evaluation.findMany({
where: { experimentId: variant.experimentId },
});
const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: variantId },
include: { testScenario: true },
});
const scenarios = await prisma.testScenario.findMany({
where: { experimentId: variant.experimentId, visible: true },
});
await Promise.all(
evaluations.map(async (evaluation) => {
const passCount = modelOutputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation)
).length;
const failCount = scenarios.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variantId,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variantId,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
})
);
};
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
const variants = await prisma.promptVariant.findMany({
where: { experimentId: evaluation.experimentId, visible: true },
});
const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } },
include: { testScenario: true },
});
await Promise.all(
variants.map(async (variant) => {
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
const passCount = outputs.filter((output) =>
evaluateOutput(output, output.testScenario, evaluation)
).length;
const failCount = outputs.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
})
);
};

View File

@@ -2,17 +2,21 @@ import { type JSONSerializable } from "../types";
export type VariableMap = Record<string, string>; export type VariableMap = Record<string, string>;
export default function fillTemplate<T extends JSONSerializable>( export function fillTemplate(template: string, variables: VariableMap): string {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
}
export function fillTemplateJson<T extends JSONSerializable>(
template: T, template: T,
variables: VariableMap variables: VariableMap
): T { ): T {
if (typeof template === "string") { if (typeof template === "string") {
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T; return fillTemplate(template, variables) as T;
} else if (Array.isArray(template)) { } else if (Array.isArray(template)) {
return template.map((item) => fillTemplate(item, variables)) as T; return template.map((item) => fillTemplateJson(item, variables)) as T;
} else if (typeof template === "object" && template !== null) { } else if (typeof template === "object" && template !== null) {
return Object.keys(template).reduce((acc, key) => { return Object.keys(template).reduce((acc, key) => {
acc[key] = fillTemplate(template[key] as JSONSerializable, variables); acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
return acc; return acc;
}, {} as { [key: string]: JSONSerializable } & T); }, {} as { [key: string]: JSONSerializable } & T);
} else { } else {