diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index dd0263d..eda9aef 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -63,6 +63,7 @@ export default function OutputCell({ const awaitingOutput = !cell || + !cell.evalsComplete || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS" || hardRefetching; diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts index 8a9a60d..29812f2 100644 --- a/src/server/api/routers/scenarioVariantCells.router.ts +++ b/src/server/api/routers/scenarioVariantCells.router.ts @@ -19,30 +19,45 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }); await requireCanViewExperiment(experimentId, ctx); - return await prisma.scenarioVariantCell.findUnique({ - where: { - promptVariantId_testScenarioId: { - promptVariantId: input.variantId, - testScenarioId: input.scenarioId, - }, - }, - include: { - modelResponses: { - where: { - outdated: false, + const [cell, numTotalEvals] = await prisma.$transaction([ + prisma.scenarioVariantCell.findUnique({ + where: { + promptVariantId_testScenarioId: { + promptVariantId: input.variantId, + testScenarioId: input.scenarioId, }, - include: { - outputEvaluations: { - include: { - evaluation: { - select: { label: true }, + }, + include: { + modelResponses: { + where: { + outdated: false, + }, + include: { + outputEvaluations: { + include: { + evaluation: { + select: { label: true }, + }, }, }, }, }, }, - }, - }); + }), + prisma.evaluation.count({ + where: { experimentId }, + }), + ]); + + if (!cell) return null; + + const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1]; + const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals; + + return { + ...cell, + evalsComplete, + }; }), forceRefetch: protectedProcedure .input( diff --git a/src/server/tasks/queryModel.task.ts b/src/server/tasks/queryModel.task.ts index de42388..929eb29 100644 --- a/src/server/tasks/queryModel.task.ts +++ b/src/server/tasks/queryModel.task.ts @@ -99,7 +99,7 @@ export const queryModel = defineTask("queryModel", async (task) = const inputHash = hashPrompt(prompt); for (let i = 0; true; i++) { - const modelResponse = await prisma.modelResponse.create({ + let modelResponse = await prisma.modelResponse.create({ data: { inputHash, scenarioVariantCellId: cellId, @@ -108,7 +108,7 @@ export const queryModel = defineTask("queryModel", async (task) = }); const response = await provider.getCompletion(prompt.modelInput, onStream); if (response.type === "success") { - await prisma.modelResponse.update({ + modelResponse = await prisma.modelResponse.update({ where: { id: modelResponse.id }, data: { output: response.value as Prisma.InputJsonObject, @@ -127,7 +127,7 @@ export const queryModel = defineTask("queryModel", async (task) = }, }); - await runEvalsForOutput(variant.experimentId, scenario, modelResponse); + await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider); break; } else { const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts index 530598c..e4a7086 100644 --- a/src/server/utils/evaluations.ts +++ b/src/server/utils/evaluations.ts @@ -2,13 +2,15 @@ import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client"; import { prisma } from "../db"; import { runOneEval } from "./runOneEval"; import { type Scenario } from "~/components/OutputsTable/types"; +import { type SupportedProvider } from "~/modelProviders/types"; -const saveResult = async ( +const runAndSaveEval = async ( evaluation: Evaluation, scenario: Scenario, modelResponse: ModelResponse, + provider: SupportedProvider, ) => { - const result = await runOneEval(evaluation, scenario, modelResponse); + const result = await runOneEval(evaluation, scenario, modelResponse, provider); return await prisma.outputEvaluation.upsert({ where: { modelResponseId_evaluationId: { @@ -31,13 +33,16 @@ export const runEvalsForOutput = async ( experimentId: string, scenario: Scenario, modelResponse: ModelResponse, + provider: SupportedProvider, ) => { const evaluations = await prisma.evaluation.findMany({ where: { experimentId }, }); await Promise.all( - evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)), + evaluations.map( + async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider), + ), ); }; @@ -62,6 +67,7 @@ export const runAllEvals = async (experimentId: string) => { scenarioVariantCell: { include: { testScenario: true, + promptVariant: true, }, }, outputEvaluations: true, @@ -73,13 +79,18 @@ export const runAllEvals = async (experimentId: string) => { await Promise.all( outputs.map(async (output) => { - const unrunEvals = evals.filter( + const evalsToBeRun = evals.filter( (evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id), ); await Promise.all( - unrunEvals.map(async (evaluation) => { - await saveResult(evaluation, output.scenarioVariantCell.testScenario, output); + evalsToBeRun.map(async (evaluation) => { + await runAndSaveEval( + evaluation, + output.scenarioVariantCell.testScenario, + output, + output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider, + ); }), ); }), diff --git a/src/server/utils/runOneEval.ts b/src/server/utils/runOneEval.ts index b38abb6..a65f417 100644 --- a/src/server/utils/runOneEval.ts +++ b/src/server/utils/runOneEval.ts @@ -1,13 +1,14 @@ import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client"; -import { type ChatCompletion } from "openai/resources/chat"; import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate"; import { openai } from "./openai"; import dedent from "dedent"; +import modelProviders from "~/modelProviders/modelProviders"; +import { type SupportedProvider } from "~/modelProviders/types"; export const runGpt4Eval = async ( evaluation: Evaluation, scenario: TestScenario, - message: ChatCompletion.Choice.Message, + stringifiedOutput: string, ): Promise<{ result: number; details: string }> => { const output = await openai.chat.completions.create({ model: "gpt-4-0613", @@ -26,11 +27,7 @@ export const runGpt4Eval = async ( }, { role: "user", - content: `The full output of the simpler message:\n---\n${JSON.stringify( - message.content ?? message.function_call, - null, - 2, - )}`, + content: `The full output of the simpler message:\n---\n${stringifiedOutput}`, }, ], function_call: { @@ -71,14 +68,15 @@ export const runOneEval = async ( evaluation: Evaluation, scenario: TestScenario, modelResponse: ModelResponse, + provider: SupportedProvider, ): Promise<{ result: number; details?: string }> => { - const output = modelResponse.output as unknown as ChatCompletion; - - const message = output?.choices?.[0]?.message; + const modelProvider = modelProviders[provider]; + const message = modelProvider.normalizeOutput(modelResponse.output); if (!message) return { result: 0 }; - const stringifiedMessage = message.content ?? JSON.stringify(message.function_call); + const stringifiedOutput = + message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value; const matchRegex = escapeRegExp( fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap), @@ -86,10 +84,10 @@ export const runOneEval = async ( switch (evaluation.evalType) { case "CONTAINS": - return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 }; + return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 }; case "DOES_NOT_CONTAIN": - return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 }; + return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 }; case "GPT4_EVAL": - return await runGpt4Eval(evaluation, scenario, message); + return await runGpt4Eval(evaluation, scenario, stringifiedOutput); } };