From f728027ef6de194c3651af83c6552adea78116f5 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Thu, 6 Jul 2023 13:39:13 -0700 Subject: [PATCH] add evaluations --- package.json | 2 + pnpm-lock.yaml | 16 +- .../migration.sql | 40 ++++ prisma/schema.prisma | 48 +++- .../OutputsTable/EditEvaluations.tsx | 223 ++++++++++++++++-- .../OutputsTable/EditScenarioVars.tsx | 4 +- .../OutputsTable/NewVariantButton.tsx | 6 +- src/components/OutputsTable/OutputCell.tsx | 45 +++- .../OutputsTable/SettingsDrawer.tsx | 8 +- src/components/OutputsTable/VariantStats.tsx | 38 +++ src/components/OutputsTable/index.tsx | 22 +- src/server/api/root.router.ts | 2 + src/server/api/routers/evaluations.router.ts | 77 ++++++ src/server/api/routers/modelOutputs.router.ts | 13 +- .../api/routers/promptVariants.router.ts | 4 +- src/server/utils/evaluateOutput.ts | 31 +++ src/server/utils/evaluations.ts | 91 +++++++ src/server/utils/fillTemplate.ts | 12 +- 18 files changed, 614 insertions(+), 68 deletions(-) create mode 100644 prisma/migrations/20230706201223_add_evaluations/migration.sql create mode 100644 src/components/OutputsTable/VariantStats.tsx create mode 100644 src/server/api/routers/evaluations.router.ts create mode 100644 src/server/utils/evaluateOutput.ts create mode 100644 src/server/utils/evaluations.ts diff --git a/package.json b/package.json index ef1c562..cd5402d 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,7 @@ "@trpc/next": "^10.26.0", "@trpc/react-query": "^10.26.0", "@trpc/server": "^10.26.0", + "chroma-js": "^2.4.2", "concurrently": "^8.2.0", "cors": "^2.8.5", "dayjs": "^1.11.8", @@ -57,6 +58,7 @@ }, "devDependencies": { "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5", + "@types/chroma-js": "^2.4.0", "@types/cors": "^2.8.13", "@types/eslint": "^8.37.0", "@types/express": "^4.17.17", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index afabca1..41da3d7 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.0' +lockfileVersion: '6.1' settings: autoInstallPeers: true @@ -50,6 +50,9 @@ dependencies: '@trpc/server': specifier: ^10.26.0 version: 10.26.0 + chroma-js: + specifier: ^2.4.2 + version: 2.4.2 concurrently: specifier: ^8.2.0 version: 8.2.0 @@ -130,6 +133,9 @@ devDependencies: '@openapi-contrib/openapi-schema-to-json-schema': specifier: ^4.0.5 version: 4.0.5 + '@types/chroma-js': + specifier: ^2.4.0 + version: 2.4.0 '@types/cors': specifier: ^2.8.13 version: 2.8.13 @@ -2072,6 +2078,10 @@ packages: '@types/node': 18.16.0 dev: true + /@types/chroma-js@2.4.0: + resolution: {integrity: sha512-JklMxityrwjBTjGY2anH8JaTx3yjRU3/sEHSblLH1ba5lqcSh1LnImXJZO5peJfXyqKYWjHTGy4s5Wz++hARrw==} + dev: true + /@types/connect@3.4.35: resolution: {integrity: sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ==} dependencies: @@ -2697,6 +2707,10 @@ packages: fsevents: 2.3.2 dev: false + /chroma-js@2.4.2: + resolution: {integrity: sha512-U9eDw6+wt7V8z5NncY2jJfZa+hUH8XEj8FQHgFJTrUFnJfXYf4Ml4adI2vXZOjqRDpFWtYVWypDfZwnJ+HIR4A==} + dev: false + /client-only@0.0.1: resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} dev: false diff --git a/prisma/migrations/20230706201223_add_evaluations/migration.sql b/prisma/migrations/20230706201223_add_evaluations/migration.sql new file mode 100644 index 0000000..2829fc6 --- /dev/null +++ b/prisma/migrations/20230706201223_add_evaluations/migration.sql @@ -0,0 +1,40 @@ +-- CreateEnum +CREATE TYPE "EvaluationMatchType" AS ENUM ('CONTAINS', 'DOES_NOT_CONTAIN'); + +-- CreateTable +CREATE TABLE "Evaluation" ( + "id" UUID NOT NULL, + "name" TEXT NOT NULL, + "matchString" TEXT NOT NULL, + "matchType" "EvaluationMatchType" NOT NULL, + "experimentId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "Evaluation_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "EvaluationResult" ( + "id" UUID NOT NULL, + "passCount" INTEGER NOT NULL, + "failCount" INTEGER NOT NULL, + "evaluationId" UUID NOT NULL, + "promptVariantId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "EvaluationResult_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "EvaluationResult_evaluationId_promptVariantId_key" ON "EvaluationResult"("evaluationId", "promptVariantId"); + +-- AddForeignKey +ALTER TABLE "Evaluation" ADD CONSTRAINT "Evaluation_experimentId_fkey" FOREIGN KEY ("experimentId") REFERENCES "Experiment"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 7e4e0d9..8a6f07b 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -22,6 +22,7 @@ model Experiment { TemplateVariable TemplateVariable[] PromptVariant PromptVariant[] TestScenario TestScenario[] + Evaluation Evaluation[] } model PromptVariant { @@ -37,9 +38,10 @@ model PromptVariant { experimentId String @db.Uuid experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - ModelOutput ModelOutput[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + ModelOutput ModelOutput[] + EvaluationResult EvaluationResult[] @@index([uiId]) } @@ -80,7 +82,7 @@ model ModelOutput { output Json statusCode Int errorMessage String? - timeToComplete Int @default(0) + timeToComplete Int @default(0) promptTokens Int? // Added promptTokens field completionTokens Int? // Added completionTokens field @@ -98,6 +100,44 @@ model ModelOutput { @@index([inputHash]) } +enum EvaluationMatchType { + CONTAINS + DOES_NOT_CONTAIN +} + +model Evaluation { + id String @id @default(uuid()) @db.Uuid + + name String + matchString String + matchType EvaluationMatchType + + experimentId String @db.Uuid + experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + EvaluationResult EvaluationResult[] +} + +model EvaluationResult { + id String @id @default(uuid()) @db.Uuid + + passCount Int + failCount Int + + evaluationId String @db.Uuid + evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) + + promptVariantId String @db.Uuid + promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([evaluationId, promptVariantId]) +} + // Necessary for Next auth model Account { id String @id @default(cuid()) diff --git a/src/components/OutputsTable/EditEvaluations.tsx b/src/components/OutputsTable/EditEvaluations.tsx index 76e0576..9a134ec 100644 --- a/src/components/OutputsTable/EditEvaluations.tsx +++ b/src/components/OutputsTable/EditEvaluations.tsx @@ -1,43 +1,216 @@ -import { Text, Heading, Stack } from "@chakra-ui/react"; -import { useState } from "react"; +import { + Text, + Button, + HStack, + Heading, + Icon, + Input, + Stack, + VStack, + FormControl, + FormLabel, + Select, + FormHelperText, +} from "@chakra-ui/react"; +import { type Evaluation, EvaluationMatchType } from "@prisma/client"; +import { useCallback, useState } from "react"; +import { BsPencil, BsX } from "react-icons/bs"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -import { useStore } from "~/utils/store"; + +type EvalValues = Pick; + +export function EvaluationEditor(props: { + evaluation: Evaluation | null; + defaultName?: string; + onSave: (id: string | undefined, vals: EvalValues) => void; + onCancel: () => void; +}) { + const [values, setValues] = useState({ + name: props.evaluation?.name ?? props.defaultName ?? "", + matchString: props.evaluation?.matchString ?? "", + matchType: props.evaluation?.matchType ?? "CONTAINS", + }); + + return ( + + + + Evaluation Name + setValues((values) => ({ ...values, name: e.target.value }))} + /> + + + Match Type + + + + + Match String + + This string will be interpreted as a regex and checked against each model output. + + setValues((values) => ({ ...values, matchString: e.target.value }))} + /> + + + + + + + ); +} export default function EditEvaluations() { const experiment = useExperiment(); - const vars = - api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; + const evaluations = + api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; - const [newVariable, setNewVariable] = useState(""); - const newVarIsValid = newVariable.length > 0 && !vars.map((v) => v.label).includes(newVariable); + const [editingId, setEditingId] = useState(null); const utils = api.useContext(); - const addVarMutation = api.templateVars.create.useMutation(); - const [onAddVar] = useHandledAsyncCallback(async () => { - if (!experiment.data?.id) return; - if (!newVarIsValid) return; - await addVarMutation.mutateAsync({ - experimentId: experiment.data.id, - label: newVariable, - }); - await utils.templateVars.list.invalidate(); - setNewVariable(""); - }, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]); + const createMutation = api.evaluations.create.useMutation(); + const updateMutation = api.evaluations.update.useMutation(); - const deleteMutation = api.templateVars.delete.useMutation(); - const [onDeleteVar] = useHandledAsyncCallback(async (id: string) => { + const deleteMutation = api.evaluations.delete.useMutation(); + const [onDelete] = useHandledAsyncCallback(async (id: string) => { await deleteMutation.mutateAsync({ id }); - await utils.templateVars.list.invalidate(); + await utils.evaluations.list.invalidate(); + await utils.evaluations.results.invalidate(); }, []); - const closeDrawer = useStore((state) => state.closeDrawer); + const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => { + setEditingId(null); + if (!experiment.data?.id) return; + + if (id) { + await updateMutation.mutateAsync({ + id, + updates: vals, + }); + } else { + await createMutation.mutateAsync({ + experimentId: experiment.data.id, + ...vals, + }); + } + await utils.evaluations.list.invalidate(); + await utils.evaluations.results.invalidate(); + }, []); + + const onCancel = useCallback(() => { + setEditingId(null); + }, []); return ( - Edit Evaluations - - + Evaluations + + + Evaluations allow you to compare prompt performance in an automated way. + + + {evaluations.map((evaluation) => + editingId == evaluation.id ? ( + + ) : ( + + {evaluation.name} + + {evaluation.matchType}: "{evaluation.matchString}" + + + + + ) + )} + {editingId == null && ( + + )} + {editingId == "new" && ( + + )} + ); diff --git a/src/components/OutputsTable/EditScenarioVars.tsx b/src/components/OutputsTable/EditScenarioVars.tsx index fb2cfe0..1e912d7 100644 --- a/src/components/OutputsTable/EditScenarioVars.tsx +++ b/src/components/OutputsTable/EditScenarioVars.tsx @@ -33,8 +33,8 @@ export default function EditScenarioVars() { return ( - Edit Scenario Variables - + Scenario Variables + Scenario variables can be used in your prompt variants as well as evaluations. Reference them using {"{{curly_braces}}"}. diff --git a/src/components/OutputsTable/NewVariantButton.tsx b/src/components/OutputsTable/NewVariantButton.tsx index 34fac45..e4d6b47 100644 --- a/src/components/OutputsTable/NewVariantButton.tsx +++ b/src/components/OutputsTable/NewVariantButton.tsx @@ -1,4 +1,4 @@ -import { Button } from "@chakra-ui/react"; +import { Button, Icon, Spinner } from "@chakra-ui/react"; import { BsPlus } from "react-icons/bs"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; @@ -9,7 +9,7 @@ export default function NewVariantButton() { const mutation = api.promptVariants.create.useMutation(); const utils = api.useContext(); - const [onClick] = useHandledAsyncCallback(async () => { + const [onClick, loading] = useHandledAsyncCallback(async () => { if (!experiment.data) return; await mutation.mutateAsync({ experimentId: experiment.data.id, @@ -30,7 +30,7 @@ export default function NewVariantButton() { height="unset" minH={headerMinHeight} > - + Add Variant ); diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index 0aed4d5..c38b663 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -1,17 +1,18 @@ import { api } from "~/utils/api"; import { type PromptVariant, type Scenario } from "./types"; -import { Spinner, Text, Box, Center, Flex, Icon } from "@chakra-ui/react"; +import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react"; import { useExperiment } from "~/utils/hooks"; import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; import { useMemo, type ReactElement } from "react"; -import { BsClock } from "react-icons/bs"; +import { BsCheck, BsClock, BsX } from "react-icons/bs"; import { type ModelOutput } from "@prisma/client"; import { type ChatCompletion } from "openai/resources/chat"; import { generateChannel } from "~/utils/generateChannel"; import { isObject } from "lodash"; import useSocket from "~/utils/useSocket"; +import { evaluateOutput } from "~/server/utils/evaluateOutput"; export default function OutputCell({ scenario, @@ -109,7 +110,7 @@ export default function OutputCell({ { maxLength: 40 } )} - + ); } @@ -120,18 +121,44 @@ export default function OutputCell({ return ( {contentToDisplay} - {output.data && } + {output.data && } ); } -const OutputStats = ({ modelOutput }: { modelOutput: ModelOutput }) => { +const OutputStats = ({ + modelOutput, + scenario, +}: { + modelOutput: ModelOutput; + scenario: Scenario; +}) => { const timeToComplete = modelOutput.timeToComplete; + const experiment = useExperiment(); + const evals = + api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; return ( - - - {(timeToComplete / 1000).toFixed(2)}s - + + + {evals.map((evaluation) => { + const passed = evaluateOutput(modelOutput, scenario, evaluation); + return ( + + {evaluation.name} + + + ); + })} + + + + {(timeToComplete / 1000).toFixed(2)}s + + ); }; diff --git a/src/components/OutputsTable/SettingsDrawer.tsx b/src/components/OutputsTable/SettingsDrawer.tsx index a07487a..1278b2d 100644 --- a/src/components/OutputsTable/SettingsDrawer.tsx +++ b/src/components/OutputsTable/SettingsDrawer.tsx @@ -6,9 +6,11 @@ import { DrawerHeader, DrawerOverlay, Heading, + Stack, } from "@chakra-ui/react"; import { useStore } from "~/utils/store"; import EditScenarioVars from "./EditScenarioVars"; +import EditEvaluations from "./EditEvaluations"; export default function SettingsDrawer() { const isOpen = useStore((state) => state.drawerOpen); @@ -23,8 +25,10 @@ export default function SettingsDrawer() { Settings - - {/* */} + + + + diff --git a/src/components/OutputsTable/VariantStats.tsx b/src/components/OutputsTable/VariantStats.tsx new file mode 100644 index 0000000..4143f54 --- /dev/null +++ b/src/components/OutputsTable/VariantStats.tsx @@ -0,0 +1,38 @@ +import { HStack, Text, useToken } from "@chakra-ui/react"; +import { type PromptVariant } from "./types"; +import { cellPadding } from "../constants"; +import { api } from "~/utils/api"; +import chroma from "chroma-js"; + +export default function VariantStats(props: { variant: PromptVariant }) { + const evalResults = + api.evaluations.results.useQuery({ + variantId: props.variant.id, + }).data ?? []; + + const [passColor, neutralColor, failColor] = useToken("colors", [ + "green.500", + "gray.500", + "red.500", + ]); + + const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]); + + if (!(evalResults.length > 0)) return null; + + return ( + + {evalResults.map((result) => { + const passedFrac = result.passCount / (result.passCount + result.failCount); + return ( + + {result.evaluation.name} + + {(passedFrac * 100).toFixed(1)}% + + + ); + })} + + ); +} diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx index c8649db..4d468b9 100644 --- a/src/components/OutputsTable/index.tsx +++ b/src/components/OutputsTable/index.tsx @@ -1,11 +1,4 @@ -import { - Button, - Grid, - GridItem, - HStack, - Heading, - type SystemStyleObject, -} from "@chakra-ui/react"; +import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react"; import { api } from "~/utils/api"; import NewScenarioButton from "./NewScenarioButton"; import NewVariantButton from "./NewVariantButton"; @@ -15,6 +8,7 @@ import VariantHeader from "./VariantHeader"; import { cellPadding } from "../constants"; import { BsPencil } from "react-icons/bs"; import { useStore } from "~/utils/store"; +import VariantStats from "./VariantStats"; const stickyHeaderStyle: SystemStyleObject = { position: "sticky", @@ -38,6 +32,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string | if (!variants.data || !scenarios.data) return null; const allCols = variants.data.length + 1; + const headerRows = 3; return ( @@ -82,7 +77,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string | ))} *" selector on Grid style={{ borderRightWidth: 0, borderBottomWidth: 0 }} @@ -92,10 +87,15 @@ export default function OutputsTable({ experimentId }: { experimentId: string | {variants.data.map((variant) => ( - + ))} + {variants.data.map((variant) => ( + + + + ))} {scenarios.data.map((scenario) => ( ))} diff --git a/src/server/api/root.router.ts b/src/server/api/root.router.ts index 7383203..cb1ff9a 100644 --- a/src/server/api/root.router.ts +++ b/src/server/api/root.router.ts @@ -4,6 +4,7 @@ import { experimentsRouter } from "./routers/experiments.router"; import { scenariosRouter } from "./routers/scenarios.router"; import { modelOutputsRouter } from "./routers/modelOutputs.router"; import { templateVarsRouter } from "./routers/templateVariables.router"; +import { evaluationsRouter } from "./routers/evaluations.router"; /** * This is the primary router for your server. @@ -16,6 +17,7 @@ export const appRouter = createTRPCRouter({ scenarios: scenariosRouter, outputs: modelOutputsRouter, templateVars: templateVarsRouter, + evaluations: evaluationsRouter, }); // export type definition of API diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts new file mode 100644 index 0000000..3f60e1d --- /dev/null +++ b/src/server/api/routers/evaluations.router.ts @@ -0,0 +1,77 @@ +import { EvaluationMatchType } from "@prisma/client"; +import { z } from "zod"; +import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import { reevaluateEvaluation } from "~/server/utils/evaluations"; + +export const evaluationsRouter = createTRPCRouter({ + list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { + return await prisma.evaluation.findMany({ + where: { + experimentId: input.experimentId, + }, + orderBy: { createdAt: "asc" }, + }); + }), + + results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => { + return await prisma.evaluationResult.findMany({ + where: { + promptVariantId: input.variantId, + }, + include: { evaluation: true }, + }); + }), + + create: publicProcedure + .input( + z.object({ + experimentId: z.string(), + name: z.string(), + matchString: z.string(), + matchType: z.nativeEnum(EvaluationMatchType), + }) + ) + .mutation(async ({ input }) => { + const evaluation = await prisma.evaluation.create({ + data: { + experimentId: input.experimentId, + name: input.name, + matchString: input.matchString, + matchType: input.matchType, + }, + }); + await reevaluateEvaluation(evaluation); + }), + + update: publicProcedure + .input( + z.object({ + id: z.string(), + updates: z.object({ + name: z.string().optional(), + matchString: z.string().optional(), + matchType: z.nativeEnum(EvaluationMatchType).optional(), + }), + }) + ) + .mutation(async ({ input }) => { + await prisma.evaluation.update({ + where: { id: input.id }, + data: { + name: input.updates.name, + matchString: input.updates.matchString, + matchType: input.updates.matchType, + }, + }); + await reevaluateEvaluation( + await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }) + ); + }), + + delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { + await prisma.evaluation.delete({ + where: { id: input.id }, + }); + }), +}); diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index 44b7b52..c2f1e57 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -1,15 +1,18 @@ import { z } from "zod"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; -import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate"; +import { fillTemplateJson, type VariableMap } from "~/server/utils/fillTemplate"; import { type JSONSerializable } from "~/server/types"; -import { getCompletion } from "~/server/utils/getCompletion"; import crypto from "crypto"; import type { Prisma } from "@prisma/client"; +import { reevaluateVariant } from "~/server/utils/evaluations"; +import { getCompletion } from "~/server/utils/getCompletion"; export const modelOutputsRouter = createTRPCRouter({ get: publicProcedure - .input(z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })) + .input( + z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }) + ) .query(async ({ input }) => { const existing = await prisma.modelOutput.findUnique({ where: { @@ -36,7 +39,7 @@ export const modelOutputsRouter = createTRPCRouter({ if (!variant || !scenario) return null; - const filledTemplate = fillTemplate( + const filledTemplate = fillTemplateJson( variant.config as JSONSerializable, scenario.variableValues as VariableMap ); @@ -73,6 +76,8 @@ export const modelOutputsRouter = createTRPCRouter({ }, }); + await reevaluateVariant(input.variantId); + return modelOutput; }), }); diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index 557c896..c519814 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -10,9 +10,7 @@ export const promptVariantsRouter = createTRPCRouter({ experimentId: input.experimentId, visible: true, }, - orderBy: { - sortIndex: "asc", - }, + orderBy: { sortIndex: "asc" }, }); }), diff --git a/src/server/utils/evaluateOutput.ts b/src/server/utils/evaluateOutput.ts new file mode 100644 index 0000000..404df09 --- /dev/null +++ b/src/server/utils/evaluateOutput.ts @@ -0,0 +1,31 @@ +import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client"; +import { type ChatCompletion } from "openai/resources/chat"; +import { type VariableMap, fillTemplate } from "./fillTemplate"; + +export const evaluateOutput = ( + modelOutput: ModelOutput, + scenario: TestScenario, + evaluation: Evaluation +): boolean => { + const output = modelOutput.output as unknown as ChatCompletion; + const message = output.choices?.[0]?.message; + + if (!message) return false; + + const stringifiedMessage = JSON.stringify(message); + + const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap); + + let match; + + switch (evaluation.matchType) { + case "CONTAINS": + match = stringifiedMessage.match(matchRegex) !== null; + break; + case "DOES_NOT_CONTAIN": + match = stringifiedMessage.match(matchRegex) === null; + break; + } + + return match; +}; diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts new file mode 100644 index 0000000..4509a00 --- /dev/null +++ b/src/server/utils/evaluations.ts @@ -0,0 +1,91 @@ +import { type Evaluation } from "@prisma/client"; +import { prisma } from "../db"; +import { evaluateOutput } from "./evaluateOutput"; + +export const reevaluateVariant = async (variantId: string) => { + const variant = await prisma.promptVariant.findUnique({ + where: { id: variantId }, + }); + if (!variant) return; + + const evaluations = await prisma.evaluation.findMany({ + where: { experimentId: variant.experimentId }, + }); + + const modelOutputs = await prisma.modelOutput.findMany({ + where: { promptVariantId: variantId }, + include: { testScenario: true }, + }); + + const scenarios = await prisma.testScenario.findMany({ + where: { experimentId: variant.experimentId, visible: true }, + }); + + await Promise.all( + evaluations.map(async (evaluation) => { + const passCount = modelOutputs.filter((output) => + evaluateOutput(output, output.testScenario, evaluation) + ).length; + const failCount = scenarios.length - passCount; + + await prisma.evaluationResult.upsert({ + where: { + evaluationId_promptVariantId: { + evaluationId: evaluation.id, + promptVariantId: variantId, + }, + }, + create: { + evaluationId: evaluation.id, + promptVariantId: variantId, + passCount, + failCount, + }, + update: { + passCount, + failCount, + }, + }); + }) + ); +}; + +export const reevaluateEvaluation = async (evaluation: Evaluation) => { + const variants = await prisma.promptVariant.findMany({ + where: { experimentId: evaluation.experimentId, visible: true }, + }); + + const modelOutputs = await prisma.modelOutput.findMany({ + where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } }, + include: { testScenario: true }, + }); + + await Promise.all( + variants.map(async (variant) => { + const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id); + const passCount = outputs.filter((output) => + evaluateOutput(output, output.testScenario, evaluation) + ).length; + const failCount = outputs.length - passCount; + + await prisma.evaluationResult.upsert({ + where: { + evaluationId_promptVariantId: { + evaluationId: evaluation.id, + promptVariantId: variant.id, + }, + }, + create: { + evaluationId: evaluation.id, + promptVariantId: variant.id, + passCount, + failCount, + }, + update: { + passCount, + failCount, + }, + }); + }) + ); +}; diff --git a/src/server/utils/fillTemplate.ts b/src/server/utils/fillTemplate.ts index 923652c..1cebd22 100644 --- a/src/server/utils/fillTemplate.ts +++ b/src/server/utils/fillTemplate.ts @@ -2,17 +2,21 @@ import { type JSONSerializable } from "../types"; export type VariableMap = Record; -export default function fillTemplate( +export function fillTemplate(template: string, variables: VariableMap): string { + return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || ""); +} + +export function fillTemplateJson( template: T, variables: VariableMap ): T { if (typeof template === "string") { - return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T; + return fillTemplate(template, variables) as T; } else if (Array.isArray(template)) { - return template.map((item) => fillTemplate(item, variables)) as T; + return template.map((item) => fillTemplateJson(item, variables)) as T; } else if (typeof template === "object" && template !== null) { return Object.keys(template).reduce((acc, key) => { - acc[key] = fillTemplate(template[key] as JSONSerializable, variables); + acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables); return acc; }, {} as { [key: string]: JSONSerializable } & T); } else {