diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index bf05a85..119a78d 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -41,7 +41,6 @@ model PromptVariant {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[]
- EvaluationResult EvaluationResult[]
@@index([uiId])
}
@@ -124,6 +123,7 @@ model ModelOutput {
scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
+ outputEvaluation OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash])
@@ -146,25 +146,26 @@ model Evaluation {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
- EvaluationResult EvaluationResult[]
+ OutputEvaluation OutputEvaluation[]
}
-model EvaluationResult {
+model OutputEvaluation {
id String @id @default(uuid()) @db.Uuid
- passCount Int
- failCount Int
+ // Score from 0 (fail) to 1 (pass); stored as a Float so values in between can be represented
+ result Float
+ details String?
+
+ modelOutputId String @db.Uuid
+ modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
- promptVariantId String @db.Uuid
- promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
-
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
- @@unique([evaluationId, promptVariantId])
+ @@unique([modelOutputId, evaluationId])
}
// Necessary for Next auth
diff --git a/src/components/OutputsTable/EditEvaluations.tsx b/src/components/OutputsTable/EditEvaluations.tsx
index 7d1a44c..a672064 100644
--- a/src/components/OutputsTable/EditEvaluations.tsx
+++ b/src/components/OutputsTable/EditEvaluations.tsx
@@ -40,7 +40,7 @@ export function EvaluationEditor(props: {
setValues((values) => ({ ...values, name: e.target.value }))}
+ onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
/>
@@ -125,6 +125,7 @@ export default function EditEvaluations() {
}
await utils.evaluations.list.invalidate();
await utils.promptVariants.stats.invalidate();
+ await utils.scenarioVariantCells.get.invalidate();
}, []);
const onCancel = useCallback(() => {
diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx
index 24a92b0..5256168 100644
--- a/src/components/OutputsTable/OutputCell/OutputStats.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx
@@ -1,10 +1,7 @@
-import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types";
-import { useExperiment } from "~/utils/hooks";
-import { api } from "~/utils/api";
+import { type RouterOutputs } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip";
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
export const OutputStats = ({
model,
modelOutput,
- scenario,
}: {
model: SupportedModel | string | null;
- modelOutput: ModelOutput;
+ modelOutput: NonNullable<
+ NonNullable["modelOutput"]
+ >;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete;
- const experiment = useExperiment();
- const evals =
- api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens;
@@ -38,11 +33,11 @@ export const OutputStats = ({
return (
- {evals.map((evaluation) => {
- const passed = evaluateOutput(modelOutput, scenario, evaluation);
+ {modelOutput.outputEvaluation.map((evaluation) => {
+ const passed = evaluation.result > 0.5;
return (
- {evaluation.label}
+ {evaluation.evaluation.label}
{data.evalResults.map((result) => {
- const passedFrac = result.passCount / (result.passCount + result.failCount);
+ const passedFrac = result.passCount / result.totalCount;
return (
- {result.evaluation.label}
+ {result.label}
{(passedFrac * 100).toFixed(1)}%
diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts
index 99c1a8c..7ee0d12 100644
--- a/src/server/api/routers/evaluations.router.ts
+++ b/src/server/api/routers/evaluations.router.ts
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
-import { reevaluateEvaluation } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -24,7 +24,7 @@ export const evaluationsRouter = createTRPCRouter({
}),
)
.mutation(async ({ input }) => {
- const evaluation = await prisma.evaluation.create({
+ await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
label: input.label,
@@ -32,7 +32,10 @@ export const evaluationsRouter = createTRPCRouter({
evalType: input.evalType,
},
});
- await reevaluateEvaluation(evaluation);
+
+ // TODO: this may be bad UX for slow evals (e.g. GPT-4 evals). We may need
+ // to kick off a background job instead of awaiting the run here.
+ await runAllEvals(input.experimentId);
}),
update: publicProcedure
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
z.object({
id: z.string(),
updates: z.object({
- name: z.string().optional(),
+ label: z.string().optional(),
value: z.string().optional(),
evalType: z.nativeEnum(EvalType).optional(),
}),
}),
)
.mutation(async ({ input }) => {
- await prisma.evaluation.update({
+ const evaluation = await prisma.evaluation.update({
where: { id: input.id },
data: {
- label: input.updates.name,
+ label: input.updates.label,
value: input.updates.value,
evalType: input.updates.evalType,
},
});
- await reevaluateEvaluation(
- await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
- );
+
+ await prisma.outputEvaluation.deleteMany({
+ where: {
+ evaluationId: evaluation.id,
+ },
+ });
+ // Re-run all evals. Other eval results will already be cached, so this
+ // should only re-run the updated one.
+ await runAllEvals(evaluation.experimentId);
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts
index ece4815..20f50d0 100644
--- a/src/server/api/routers/promptVariants.router.ts
+++ b/src/server/api/routers/promptVariants.router.ts
@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
- const evalResults = await prisma.evaluationResult.findMany({
- where: {
- promptVariantId: input.variantId,
+ const outputEvals = await prisma.outputEvaluation.groupBy({
+ by: ["evaluationId"],
+ _sum: {
+ result: true,
},
- include: { evaluation: true },
+ _count: {
+ id: true,
+ },
+ where: {
+ modelOutput: {
+ scenarioVariantCell: {
+ promptVariant: {
+ id: input.variantId,
+ visible: true,
+ },
+ testScenario: {
+ visible: true,
+ },
+ },
+ },
+ },
+ });
+
+ const evals = await prisma.evaluation.findMany({
+ where: {
+ experimentId: variant.experimentId,
+ },
+ });
+
+ const evalResults = evals.map((evalItem) => {
+ const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
+ return {
+ id: evalItem.id,
+ label: evalItem.label,
+ passCount: evalResult?._sum?.result ?? 0,
+ totalCount: evalResult?._count?.id ?? 1,
+ };
});
const scenarioCount = await prisma.testScenario.count({
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
- isNot: null,
+ is: {},
},
},
});
diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts
index 09e1172..b07657e 100644
--- a/src/server/api/routers/scenarioVariantCells.router.ts
+++ b/src/server/api/routers/scenarioVariantCells.router.ts
@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
},
},
include: {
- modelOutput: true,
+ modelOutput: {
+ include: {
+ outputEvaluation: {
+ include: {
+ evaluation: {
+ select: { label: true },
+ },
+ },
+ },
+ },
+ },
},
});
}),
diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts
index 5075ae7..0ddfb0b 100644
--- a/src/server/api/routers/scenarios.router.ts
+++ b/src/server/api/routers/scenarios.router.ts
@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { reevaluateAll } from "~/server/utils/evaluations";
+import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
export const scenariosRouter = createTRPCRouter({
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
});
// Reevaluate all evaluations now that this scenario is hidden
- await reevaluateAll(hiddenScenario.experimentId);
+ await runAllEvals(hiddenScenario.experimentId);
return hiddenScenario;
}),
diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryLLM.task.ts
index 541d17b..1eb94c7 100644
--- a/src/server/tasks/queryLLM.task.ts
+++ b/src/server/tasks/queryLLM.task.ts
@@ -6,7 +6,7 @@ import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
-import { reevaluateVariant } from "../utils/evaluations";
+import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client";
@@ -148,5 +148,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
},
});
- await reevaluateVariant(cell.promptVariantId);
+ if (modelOutput) {
+ await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
+ }
});
diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts
index 3b24334..8740b88 100644
--- a/src/server/utils/evaluations.ts
+++ b/src/server/utils/evaluations.ts
@@ -1,105 +1,154 @@
import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db";
-import { evaluateOutput } from "./evaluateOutput";
+import { runOneEval } from "./runOneEval";
+import { type Scenario } from "~/components/OutputsTable/types";
-export const reevaluateVariant = async (variantId: string) => {
- const variant = await prisma.promptVariant.findUnique({
- where: { id: variantId },
- });
- if (!variant) return;
-
- const evaluations = await prisma.evaluation.findMany({
- where: { experimentId: variant.experimentId },
- });
-
- const cells = await prisma.scenarioVariantCell.findMany({
+/**
+ * Runs one evaluation against one model output and persists the score.
+ *
+ * The row is addressed by the (modelOutputId, evaluationId) pair, so calling
+ * this again for the same pair overwrites the stored `result` instead of
+ * inserting a second row.
+ *
+ * @param evaluation - The evaluation to run.
+ * @param scenario - Scenario handed to `runOneEval` together with the output.
+ * @param modelOutput - The model output being scored.
+ * @returns The created or updated `OutputEvaluation` row.
+ */
+const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
+  const score = runOneEval(evaluation, scenario, modelOutput);
+  // The same pair identifies the row for lookup and seeds it on first insert.
+  const rowKey = { modelOutputId: modelOutput.id, evaluationId: evaluation.id };
+
+  return await prisma.outputEvaluation.upsert({
     where: {
-      promptVariantId: variantId,
-      retrievalStatus: "COMPLETE",
-      testScenario: { visible: true },
-      modelOutput: { isNot: null },
+      modelOutputId_evaluationId: rowKey,
+    },
+    create: { ...rowKey, result: score },
+    update: {
+      result: score,
     },
-    include: { testScenario: true, modelOutput: true },
   });
-
-  await Promise.all(
-    evaluations.map(async (evaluation) => {
-      const passCount = cells.filter((cell) =>
-        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
-      ).length;
-      const failCount = cells.length - passCount;
-
-      await prisma.evaluationResult.upsert({
-        where: {
-          evaluationId_promptVariantId: {
-            evaluationId: evaluation.id,
-            promptVariantId: variantId,
-          },
-        },
-        create: {
-          evaluationId: evaluation.id,
-          promptVariantId: variantId,
-          passCount,
-          failCount,
-        },
-        update: {
-          passCount,
-          failCount,
-        },
-      });
-    }),
-  );
 };
-export const reevaluateEvaluation = async (evaluation: Evaluation) => {
- const variants = await prisma.promptVariant.findMany({
- where: { experimentId: evaluation.experimentId, visible: true },
- });
-
- const cells = await prisma.scenarioVariantCell.findMany({
- where: {
- promptVariantId: { in: variants.map((v) => v.id) },
- testScenario: { visible: true },
- statusCode: { notIn: [429] },
- modelOutput: { isNot: null },
- },
- include: { testScenario: true, modelOutput: true },
- });
-
- await Promise.all(
- variants.map(async (variant) => {
- const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
- const passCount = variantCells.filter((cell) =>
- evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
- ).length;
- const failCount = variantCells.length - passCount;
-
- await prisma.evaluationResult.upsert({
- where: {
- evaluationId_promptVariantId: {
- evaluationId: evaluation.id,
- promptVariantId: variant.id,
- },
- },
- create: {
- evaluationId: evaluation.id,
- promptVariantId: variant.id,
- passCount,
- failCount,
- },
- update: {
- passCount,
- failCount,
- },
- });
- }),
- );
-};
-
-export const reevaluateAll = async (experimentId: string) => {
+export const runEvalsForOutput = async (
+ experimentId: string,
+ scenario: Scenario,
+ modelOutput: ModelOutput,
+) => {
const evaluations = await prisma.evaluation.findMany({
where: { experimentId },
});
- await Promise.all(evaluations.map(reevaluateEvaluation));
+ await Promise.all(
+ evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
+ );
+
+ // const cells = await prisma.scenarioVariantCell.findMany({
+ // where: {
+ // promptVariantId: variantId,
+ // retrievalStatus: "COMPLETE",
+ // testScenario: { visible: true },
+ // },
+ // include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } },
+ // });
+
+ // await Promise.all(
+ // evaluations.map(async (evaluation) => {
+ // const passCount = cells.filter((cell) =>
+ // runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
+ // ).length;
+ // const failCount = cells.length - passCount;
+
+ // await prisma.evaluationResult.upsert({
+ // where: {
+ // evaluationId_promptVariantId: {
+ // evaluationId: evaluation.id,
+ // promptVariantId: variantId,
+ // },
+ // },
+ // create: {
+ // evaluationId: evaluation.id,
+ // promptVariantId: variantId,
+ // passCount,
+ // failCount,
+ // },
+ // update: {
+ // passCount,
+ // failCount,
+ // },
+ // });
+ // }),
+ // );
+};
+
+/**
+ * Back-fills missing evaluation results for an experiment.
+ *
+ * Loads every model output that belongs to a visible prompt variant and a
+ * visible test scenario of the experiment, then runs only those evaluations
+ * that have no stored result for that output yet. Existing results are left
+ * untouched, so a caller that needs a fresh score has to delete the stale
+ * rows before calling this.
+ *
+ * @param experimentId - Experiment whose outputs and evaluations are processed.
+ */
+export const runAllEvals = async (experimentId: string) => {
+  const outputs = await prisma.modelOutput.findMany({
+    where: {
+      scenarioVariantCell: {
+        promptVariant: {
+          experimentId,
+          visible: true,
+        },
+        testScenario: {
+          visible: true,
+        },
+      },
+    },
+    include: {
+      scenarioVariantCell: {
+        include: {
+          testScenario: true,
+        },
+      },
+      outputEvaluation: true,
+    },
+  });
+  const evals = await prisma.evaluation.findMany({
+    where: { experimentId },
+  });
+
+  await Promise.all(
+    outputs.map(async (output) => {
+      // Ids of the evaluations that already have a stored result for this output.
+      const alreadyRun = new Set(output.outputEvaluation.map((e) => e.evaluationId));
+      const unrunEvals = evals.filter((evaluation) => !alreadyRun.has(evaluation.id));
+
+      await Promise.all(
+        unrunEvals.map((evaluation) =>
+          saveResult(evaluation, output.scenarioVariantCell.testScenario, output),
+        ),
+      );
+    }),
+  );
 };
diff --git a/src/server/utils/evaluateOutput.ts b/src/server/utils/runOneEval.ts
similarity index 74%
rename from src/server/utils/evaluateOutput.ts
rename to src/server/utils/runOneEval.ts
index accf9df..619f271 100644
--- a/src/server/utils/evaluateOutput.ts
+++ b/src/server/utils/runOneEval.ts
@@ -2,30 +2,31 @@ import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/cl
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate";
-export const evaluateOutput = (
-  modelOutput: ModelOutput,
-  scenario: TestScenario,
+/**
+ * Scores one model output against one evaluation.
+ *
+ * The evaluation's `value` is passed through `fillTemplate` with the
+ * scenario's variable values and the resulting string is handed to
+ * `String.prototype.match`, which interprets it as a regular expression.
+ * NOTE(review): a pattern that is not a valid regular expression makes
+ * `match` throw; nothing here catches that.
+ *
+ * @param evaluation - Supplies the eval type and the pattern template.
+ * @param scenario - Supplies the variable values used to fill the template.
+ * @param modelOutput - Its `output` is read as a chat completion.
+ * @returns 1 when the evaluation passes, 0 when it fails. Also 0 when the
+ *   output has no first-choice message or the eval type is not handled below.
+ */
+export const runOneEval = (
   evaluation: Evaluation,
-): boolean => {
+  scenario: TestScenario,
+  modelOutput: ModelOutput,
+): number => {
   const output = modelOutput.output as unknown as ChatCompletion;
+
   const message = output?.choices?.[0]?.message;
-  if (!message) return false;
+  if (!message) return 0;
   const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
   const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
-  let match;
+  // Start from "fail" so an eval type without a case below still returns a number.
+  let result = 0;
   switch (evaluation.evalType) {
     case "CONTAINS":
-      match = stringifiedMessage.match(matchRegex) !== null;
+      result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0;
       break;
     case "DOES_NOT_CONTAIN":
-      match = stringifiedMessage.match(matchRegex) === null;
+      result = stringifiedMessage.match(matchRegex) === null ? 1 : 0;
       break;
   }
-  return match;
+  return result;
 };