From 98b231c8bd7612871400acc1029cf6e827c9dd4e Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:54:38 -0700 Subject: [PATCH] Store multiple ModelResponses (#95) * Store multiple ModelResponses * Fix prettier * Add CellContent container --- .../migration.sql | 52 ++++++++ prisma/schema.prisma | 35 ++--- .../OutputsTable/OutputCell/CellContent.tsx | 17 +++ .../OutputsTable/OutputCell/CellOptions.tsx | 6 +- .../OutputsTable/OutputCell/OutputCell.tsx | 120 +++++++++++++----- .../OutputsTable/OutputCell/OutputStats.tsx | 24 ++-- .../OutputsTable/OutputCell/ResponseLog.tsx | 22 ++++ .../{ErrorHandler.tsx => RetryCountdown.tsx} | 32 ++--- src/server/api/routers/experiments.router.ts | 26 ++-- .../api/routers/promptVariants.router.ts | 20 ++- .../routers/scenarioVariantCells.router.ts | 18 +-- src/server/tasks/queryModel.task.ts | 51 ++++---- src/server/utils/evaluations.ts | 26 ++-- src/server/utils/generateNewCell.ts | 45 +++++-- src/server/utils/runOneEval.ts | 6 +- 15 files changed, 341 insertions(+), 159 deletions(-) create mode 100644 prisma/migrations/20230725191512_migrate_model_response/migration.sql create mode 100644 src/components/OutputsTable/OutputCell/CellContent.tsx create mode 100644 src/components/OutputsTable/OutputCell/ResponseLog.tsx rename src/components/OutputsTable/OutputCell/{ErrorHandler.tsx => RetryCountdown.tsx} (50%) diff --git a/prisma/migrations/20230725191512_migrate_model_response/migration.sql b/prisma/migrations/20230725191512_migrate_model_response/migration.sql new file mode 100644 index 0000000..a0b3457 --- /dev/null +++ b/prisma/migrations/20230725191512_migrate_model_response/migration.sql @@ -0,0 +1,52 @@ +-- DropForeignKey +ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey"; + +-- DropForeignKey +ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey"; + +-- DropIndex +DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key"; + +-- AlterTable +ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId"; + +-- AlterTable +ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime", +DROP COLUMN "statusCode", +ADD COLUMN "jobQueuedAt" TIMESTAMP(3), +ADD COLUMN "jobStartedAt" TIMESTAMP(3); + +ALTER TABLE "ModelOutput" RENAME TO "ModelResponse"; + +ALTER TABLE "ModelResponse" +ADD COLUMN "requestedAt" TIMESTAMP(3), +ADD COLUMN "receivedAt" TIMESTAMP(3), +ADD COLUMN "statusCode" INTEGER, +ADD COLUMN "errorMessage" TEXT, +ADD COLUMN "retryTime" TIMESTAMP(3), +ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false; + +-- 3. Remove the unnecessary column +ALTER TABLE "ModelResponse" +DROP COLUMN "timeToComplete"; + +-- AlterTable +ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey"; +ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL; + +-- DropIndex +DROP INDEX "ModelOutput_scenarioVariantCellId_key"; + +-- AddForeignKey +ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- RenameIndex +ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx"; + +-- CreateIndex +CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId"); + +-- AddForeignKey +ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE; + + diff --git a/prisma/schema.prisma b/prisma/schema.prisma index ca11195..11c6760 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -90,12 +90,11 @@ enum CellRetrievalStatus { model ScenarioVariantCell { id String @id @default(uuid()) @db.Uuid - statusCode Int? - errorMessage String? - retryTime DateTime? retrievalStatus CellRetrievalStatus @default(COMPLETE) - - modelOutput ModelOutput? + jobQueuedAt DateTime? + jobStartedAt DateTime? + modelResponses ModelResponse[] + errorMessage String? // Contains errors that occurred independently of model responses promptVariantId String @db.Uuid promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) @@ -110,15 +109,20 @@ model ScenarioVariantCell { @@unique([promptVariantId, testScenarioId]) } -model ModelOutput { +model ModelResponse { id String @id @default(uuid()) @db.Uuid - inputHash String - output Json - timeToComplete Int @default(0) - cost Float? - promptTokens Int? - completionTokens Int? + inputHash String + requestedAt DateTime? + receivedAt DateTime? + output Json? + cost Float? + promptTokens Int? + completionTokens Int? + statusCode Int? + errorMessage String? + retryTime DateTime? + outdated Boolean @default(false) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -127,7 +131,6 @@ model ModelOutput { scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) outputEvaluations OutputEvaluation[] - @@unique([scenarioVariantCellId]) @@index([inputHash]) } @@ -159,8 +162,8 @@ model OutputEvaluation { result Float details String? - modelOutputId String @db.Uuid - modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade) + modelResponseId String @db.Uuid + modelResponse ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade) evaluationId String @db.Uuid evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) @@ -168,7 +171,7 @@ model OutputEvaluation { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt - @@unique([modelOutputId, evaluationId]) + @@unique([modelResponseId, evaluationId]) } model Organization { diff --git a/src/components/OutputsTable/OutputCell/CellContent.tsx b/src/components/OutputsTable/OutputCell/CellContent.tsx new file mode 100644 index 0000000..14b3572 --- /dev/null +++ b/src/components/OutputsTable/OutputCell/CellContent.tsx @@ -0,0 +1,17 @@ +import { type StackProps, VStack } from "@chakra-ui/react"; +import { CellOptions } from "./CellOptions"; + +export const CellContent = ({ + hardRefetch, + hardRefetching, + children, + ...props +}: { + hardRefetch: () => void; + hardRefetching: boolean; +} & StackProps) => ( + + + {children} + +); diff --git a/src/components/OutputsTable/OutputCell/CellOptions.tsx b/src/components/OutputsTable/OutputCell/CellOptions.tsx index 00c7836..dfdd139 100644 --- a/src/components/OutputsTable/OutputCell/CellOptions.tsx +++ b/src/components/OutputsTable/OutputCell/CellOptions.tsx @@ -1,4 +1,4 @@ -import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; +import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react"; import { BsArrowClockwise } from "react-icons/bs"; import { useExperimentAccess } from "~/utils/hooks"; @@ -12,7 +12,7 @@ export const CellOptions = ({ const { canModify } = useExperimentAccess(); return ( - {!refetchingOutput && canModify && ( + {canModify && ( )} diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index e8450ef..dd0263d 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -1,16 +1,19 @@ import { api } from "~/utils/api"; import { type PromptVariant, type Scenario } from "../types"; -import { Spinner, Text, Center, VStack } from "@chakra-ui/react"; +import { Text, VStack } from "@chakra-ui/react"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; -import { type ReactElement, useState, useEffect } from "react"; +import { type ReactElement, useState, useEffect, Fragment } from "react"; import useSocket from "~/utils/useSocket"; import { OutputStats } from "./OutputStats"; -import { ErrorHandler } from "./ErrorHandler"; -import { CellOptions } from "./CellOptions"; +import { RetryCountdown } from "./RetryCountdown"; import frontendModelProviders from "~/modelProviders/frontendModelProviders"; +import { ResponseLog } from "./ResponseLog"; +import { CellContent } from "./CellContent"; + +const WAITING_MESSAGE_INTERVAL = 20000; export default function OutputCell({ scenario, @@ -65,46 +68,91 @@ export default function OutputCell({ hardRefetching; useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]); - const modelOutput = cell?.modelOutput; - // TODO: disconnect from socket if we're not streaming anymore const streamedMessage = useSocket(cell?.id); if (!vars) return null; - if (disabledReason) return {disabledReason}; - - if (awaitingOutput && !streamedMessage) - return ( -
- -
- ); - if (!cell && !fetchingOutput) return ( - - + Error retrieving output - + ); if (cell && cell.errorMessage) { return ( - - - - + + {cell.errorMessage} + ); } - const normalizedOutput = modelOutput - ? provider.normalizeOutput(modelOutput.output) + if (disabledReason) return {disabledReason}; + + const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1]; + const showLogs = !streamedMessage && !mostRecentResponse?.output; + + if (showLogs) + return ( + + {cell?.jobQueuedAt && } + {cell?.jobStartedAt && } + {cell?.modelResponses?.map((response) => { + let numWaitingMessages = 0; + const relativeWaitingTime = response.receivedAt + ? response.receivedAt.getTime() + : Date.now(); + if (response.requestedAt) { + numWaitingMessages = Math.floor( + (relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL, + ); + } + return ( + + {response.requestedAt && ( + + )} + {response.requestedAt && + Array.from({ length: numWaitingMessages }, (_, i) => ( + + ))} + {response.receivedAt && ( + + )} + + ); + }) ?? null} + {mostRecentResponse?.retryTime && ( + + )} + + ); + + const normalizedOutput = mostRecentResponse?.output + ? provider.normalizeOutput(mostRecentResponse?.output) : streamedMessage ? provider.normalizeOutput(streamedMessage) : null; - if (modelOutput && normalizedOutput?.type === "json") { + if (mostRecentResponse?.output && normalizedOutput?.type === "json") { return ( - - + {stringify(normalizedOutput.value, { maxLength: 40 })} - - + + ); } @@ -138,10 +191,13 @@ export default function OutputCell({ return ( - - {contentToDisplay} + + {contentToDisplay} + - {modelOutput && } + {mostRecentResponse?.output && ( + + )} ); } diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx index 9a77a3a..2797463 100644 --- a/src/components/OutputsTable/OutputCell/OutputStats.tsx +++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx @@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip"; const SHOW_TIME = true; export const OutputStats = ({ - modelOutput, + modelResponse, }: { - modelOutput: NonNullable< - NonNullable["modelOutput"] + modelResponse: NonNullable< + NonNullable["modelResponses"][0] >; scenario: Scenario; }) => { - const timeToComplete = modelOutput.timeToComplete; + const timeToComplete = + modelResponse.receivedAt && modelResponse.requestedAt + ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime() + : 0; - const promptTokens = modelOutput.promptTokens; - const completionTokens = modelOutput.completionTokens; + const promptTokens = modelResponse.promptTokens; + const completionTokens = modelResponse.completionTokens; return ( - {modelOutput.outputEvaluations.map((evaluation) => { + {modelResponse.outputEvaluations.map((evaluation) => { const passed = evaluation.result > 0.5; return ( {evaluation.evaluation.label} @@ -42,15 +46,15 @@ export const OutputStats = ({ ); })} - {modelOutput.cost && ( + {modelResponse.cost && ( - {modelOutput.cost.toFixed(3)} + {modelResponse.cost.toFixed(3)} )} diff --git a/src/components/OutputsTable/OutputCell/ResponseLog.tsx b/src/components/OutputsTable/OutputCell/ResponseLog.tsx new file mode 100644 index 0000000..6672321 --- /dev/null +++ b/src/components/OutputsTable/OutputCell/ResponseLog.tsx @@ -0,0 +1,22 @@ +import { HStack, VStack, Text } from "@chakra-ui/react"; +import dayjs from "dayjs"; + +export const ResponseLog = ({ + time, + title, + message, +}: { + time: Date; + title: string; + message?: string; +}) => { + return ( + + + {dayjs(time).format("HH:mm:ss")} + {title} + + {message && {message}} + + ); +}; diff --git a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx b/src/components/OutputsTable/OutputCell/RetryCountdown.tsx similarity index 50% rename from src/components/OutputsTable/OutputCell/ErrorHandler.tsx rename to src/components/OutputsTable/OutputCell/RetryCountdown.tsx index b626bcc..d836da3 100644 --- a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx +++ b/src/components/OutputsTable/OutputCell/RetryCountdown.tsx @@ -1,21 +1,12 @@ -import { type ScenarioVariantCell } from "@prisma/client"; -import { VStack, Text } from "@chakra-ui/react"; +import { Text } from "@chakra-ui/react"; import { useEffect, useState } from "react"; import pluralize from "pluralize"; -export const ErrorHandler = ({ - cell, - refetchOutput, -}: { - cell: ScenarioVariantCell; - refetchOutput: () => void; -}) => { +export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => { const [msToWait, setMsToWait] = useState(0); useEffect(() => { - if (!cell.retryTime) return; - - const initialWaitTime = cell.retryTime.getTime() - Date.now(); + const initialWaitTime = retryTime.getTime() - Date.now(); const msModuloOneSecond = initialWaitTime % 1000; let remainingTime = initialWaitTime - msModuloOneSecond; setMsToWait(remainingTime); @@ -36,18 +27,13 @@ export const ErrorHandler = ({ clearInterval(interval); clearTimeout(timeout); }; - }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]); + }, [retryTime]); + + if (msToWait <= 0) return null; return ( - - - {cell.errorMessage} - - {msToWait > 0 && ( - - Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}... - - )} - + + Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}... + ); }; diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts index c5f4bf4..26f87ff 100644 --- a/src/server/api/routers/experiments.router.ts +++ b/src/server/api/routers/experiments.router.ts @@ -118,7 +118,7 @@ export const experimentsRouter = createTRPCRouter({ }, }, include: { - modelOutput: { + modelResponses: { include: { outputEvaluations: true, }, @@ -177,11 +177,11 @@ export const experimentsRouter = createTRPCRouter({ } const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = []; - const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = []; + const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = []; const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = []; for (const cell of existingCells) { const newCellId = uuidv4(); - const { modelOutput, ...cellData } = cell; + const { modelResponses, ...cellData } = cell; cellsToCreate.push({ ...cellData, id: newCellId, @@ -189,20 +189,20 @@ export const experimentsRouter = createTRPCRouter({ testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "", prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined, }); - if (modelOutput) { - const newModelOutputId = uuidv4(); - const { outputEvaluations, ...modelOutputData } = modelOutput; - modelOutputsToCreate.push({ - ...modelOutputData, - id: newModelOutputId, + for (const modelResponse of modelResponses) { + const newModelResponseId = uuidv4(); + const { outputEvaluations, ...modelResponseData } = modelResponse; + modelResponsesToCreate.push({ + ...modelResponseData, + id: newModelResponseId, scenarioVariantCellId: newCellId, - output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined, + output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, }); for (const evaluation of outputEvaluations) { outputEvaluationsToCreate.push({ ...evaluation, id: uuidv4(), - modelOutputId: newModelOutputId, + modelResponseId: newModelResponseId, evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "", }); } @@ -245,8 +245,8 @@ export const experimentsRouter = createTRPCRouter({ prisma.scenarioVariantCell.createMany({ data: cellsToCreate, }), - prisma.modelOutput.createMany({ - data: modelOutputsToCreate, + prisma.modelResponse.createMany({ + data: modelResponsesToCreate, }), prisma.evaluation.createMany({ data: evaluationsToCreate, diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index 8832dc4..92e6cd5 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -1,6 +1,7 @@ import { z } from "zod"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; +import { Prisma } from "@prisma/client"; import { generateNewCell } from "~/server/utils/generateNewCell"; import userError from "~/server/utils/error"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; @@ -51,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({ id: true, }, where: { - modelOutput: { + modelResponse: { + outdated: false, + output: { not: Prisma.AnyNull }, scenarioVariantCell: { promptVariant: { id: input.variantId, @@ -93,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({ where: { promptVariantId: input.variantId, testScenario: { visible: true }, - modelOutput: { - is: {}, + modelResponses: { + some: { + outdated: false, + output: { + not: Prisma.AnyNull, + }, + }, }, }, }); - const overallTokens = await prisma.modelOutput.aggregate({ + const overallTokens = await prisma.modelResponse.aggregate({ where: { + outdated: false, + output: { + not: Prisma.AnyNull, + }, scenarioVariantCell: { promptVariantId: input.variantId, testScenario: { diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts index b0dbecb..8a9a60d 100644 --- a/src/server/api/routers/scenarioVariantCells.router.ts +++ b/src/server/api/routers/scenarioVariantCells.router.ts @@ -27,7 +27,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }, }, include: { - modelOutput: { + modelResponses: { + where: { + outdated: false, + }, include: { outputEvaluations: { include: { @@ -62,7 +65,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ testScenarioId: input.scenarioId, }, }, - include: { modelOutput: true }, }); if (!cell) { @@ -70,12 +72,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ return; } - if (cell.modelOutput) { - // TODO: Maybe keep these around to show previous generations? - await prisma.modelOutput.delete({ - where: { id: cell.modelOutput.id }, - }); - } + await prisma.modelResponse.updateMany({ + where: { scenarioVariantCellId: cell.id }, + data: { + outdated: true, + }, + }); await queueQueryModel(cell.id, true); }), diff --git a/src/server/tasks/queryModel.task.ts b/src/server/tasks/queryModel.task.ts index 64b4f61..de42388 100644 --- a/src/server/tasks/queryModel.task.ts +++ b/src/server/tasks/queryModel.task.ts @@ -29,17 +29,9 @@ export const queryModel = defineTask("queryModel", async (task) = const { cellId, stream } = task; const cell = await prisma.scenarioVariantCell.findUnique({ where: { id: cellId }, - include: { modelOutput: true }, + include: { modelResponses: true }, }); if (!cell) { - await prisma.scenarioVariantCell.update({ - where: { id: cellId }, - data: { - statusCode: 404, - errorMessage: "Cell not found", - retrievalStatus: "ERROR", - }, - }); return; } @@ -51,6 +43,7 @@ export const queryModel = defineTask("queryModel", async (task) = where: { id: cellId }, data: { retrievalStatus: "IN_PROGRESS", + jobStartedAt: new Date(), }, }); @@ -61,7 +54,6 @@ export const queryModel = defineTask("queryModel", async (task) = await prisma.scenarioVariantCell.update({ where: { id: cellId }, data: { - statusCode: 404, errorMessage: "Prompt Variant not found", retrievalStatus: "ERROR", }, @@ -76,7 +68,6 @@ export const queryModel = defineTask("queryModel", async (task) = await prisma.scenarioVariantCell.update({ where: { id: cellId }, data: { - statusCode: 404, errorMessage: "Scenario not found", retrievalStatus: "ERROR", }, @@ -90,7 +81,6 @@ export const queryModel = defineTask("queryModel", async (task) = await prisma.scenarioVariantCell.update({ where: { id: cellId }, data: { - statusCode: 400, errorMessage: prompt.error, retrievalStatus: "ERROR", }, @@ -106,17 +96,24 @@ export const queryModel = defineTask("queryModel", async (task) = } : null; + const inputHash = hashPrompt(prompt); + for (let i = 0; true; i++) { + const modelResponse = await prisma.modelResponse.create({ + data: { + inputHash, + scenarioVariantCellId: cellId, + requestedAt: new Date(), + }, + }); const response = await provider.getCompletion(prompt.modelInput, onStream); if (response.type === "success") { - const inputHash = hashPrompt(prompt); - - const modelOutput = await prisma.modelOutput.create({ + await prisma.modelResponse.update({ + where: { id: modelResponse.id }, data: { - scenarioVariantCellId: cellId, - inputHash, output: response.value as Prisma.InputJsonObject, - timeToComplete: response.timeToComplete, + statusCode: response.statusCode, + receivedAt: new Date(), promptTokens: response.promptTokens, completionTokens: response.completionTokens, cost: response.cost, @@ -126,30 +123,35 @@ export const queryModel = defineTask("queryModel", async (task) = await prisma.scenarioVariantCell.update({ where: { id: cellId }, data: { - statusCode: response.statusCode, retrievalStatus: "COMPLETE", }, }); - await runEvalsForOutput(variant.experimentId, scenario, modelOutput); + await runEvalsForOutput(variant.experimentId, scenario, modelResponse); break; } else { const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; const delay = calculateDelay(i); - await prisma.scenarioVariantCell.update({ - where: { id: cellId }, + await prisma.modelResponse.update({ + where: { id: modelResponse.id }, data: { - errorMessage: response.message, statusCode: response.statusCode, + errorMessage: response.message, + receivedAt: new Date(), retryTime: shouldRetry ? new Date(Date.now() + delay) : null, - retrievalStatus: "ERROR", }, }); if (shouldRetry) { await sleep(delay); } else { + await prisma.scenarioVariantCell.update({ + where: { id: cellId }, + data: { + retrievalStatus: "ERROR", + }, + }); break; } } @@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => { data: { retrievalStatus: "PENDING", errorMessage: null, + jobQueuedAt: new Date(), }, }), queryModel.enqueue({ cellId, stream }), diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts index 1bcadab..530598c 100644 --- a/src/server/utils/evaluations.ts +++ b/src/server/utils/evaluations.ts @@ -1,19 +1,23 @@ -import { type ModelOutput, type Evaluation } from "@prisma/client"; +import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client"; import { prisma } from "../db"; import { runOneEval } from "./runOneEval"; import { type Scenario } from "~/components/OutputsTable/types"; -const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => { - const result = await runOneEval(evaluation, scenario, modelOutput); +const saveResult = async ( + evaluation: Evaluation, + scenario: Scenario, + modelResponse: ModelResponse, +) => { + const result = await runOneEval(evaluation, scenario, modelResponse); return await prisma.outputEvaluation.upsert({ where: { - modelOutputId_evaluationId: { - modelOutputId: modelOutput.id, + modelResponseId_evaluationId: { + modelResponseId: modelResponse.id, evaluationId: evaluation.id, }, }, create: { - modelOutputId: modelOutput.id, + modelResponseId: modelResponse.id, evaluationId: evaluation.id, ...result, }, @@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu export const runEvalsForOutput = async ( experimentId: string, scenario: Scenario, - modelOutput: ModelOutput, + modelResponse: ModelResponse, ) => { const evaluations = await prisma.evaluation.findMany({ where: { experimentId }, }); await Promise.all( - evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)), + evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)), ); }; export const runAllEvals = async (experimentId: string) => { - const outputs = await prisma.modelOutput.findMany({ + const outputs = await prisma.modelResponse.findMany({ where: { + outdated: false, + output: { + not: Prisma.AnyNull, + }, scenarioVariantCell: { promptVariant: { experimentId, diff --git a/src/server/utils/generateNewCell.ts b/src/server/utils/generateNewCell.ts index 0c4c2be..dcc9977 100644 --- a/src/server/utils/generateNewCell.ts +++ b/src/server/utils/generateNewCell.ts @@ -1,4 +1,4 @@ -import { type Prisma } from "@prisma/client"; +import { Prisma } from "@prisma/client"; import { prisma } from "../db"; import parseConstructFn from "./parseConstructFn"; import { type JsonObject } from "type-fest"; @@ -35,7 +35,7 @@ export const generateNewCell = async ( }, }, include: { - modelOutput: true, + modelResponses: true, }, }); @@ -51,8 +51,6 @@ export const generateNewCell = async ( data: { promptVariantId: variantId, testScenarioId: scenarioId, - statusCode: 400, - errorMessage: parsedConstructFn.error, retrievalStatus: "ERROR", }, }); @@ -69,36 +67,55 @@ export const generateNewCell = async ( retrievalStatus: "PENDING", }, include: { - modelOutput: true, + modelResponses: true, }, }); - const matchingModelOutput = await prisma.modelOutput.findFirst({ - where: { inputHash }, + const matchingModelResponse = await prisma.modelResponse.findFirst({ + where: { + inputHash, + output: { + not: Prisma.AnyNull, + }, + }, + orderBy: { + receivedAt: "desc", + }, + include: { + scenarioVariantCell: true, + }, + take: 1, }); - if (matchingModelOutput) { - const newModelOutput = await prisma.modelOutput.create({ + if (matchingModelResponse) { + const newModelResponse = await prisma.modelResponse.create({ data: { - ...omit(matchingModelOutput, ["id"]), + ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]), scenarioVariantCellId: cell.id, - output: matchingModelOutput.output as Prisma.InputJsonValue, + output: matchingModelResponse.output as Prisma.InputJsonValue, }, }); + await prisma.scenarioVariantCell.update({ where: { id: cell.id }, - data: { retrievalStatus: "COMPLETE" }, + data: { + retrievalStatus: "COMPLETE", + jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt, + jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt, + }, }); // Copy over all eval results as well await Promise.all( ( - await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } }) + await prisma.outputEvaluation.findMany({ + where: { modelResponseId: matchingModelResponse.id }, + }) ).map(async (evaluation) => { await prisma.outputEvaluation.create({ data: { ...omit(evaluation, ["id"]), - modelOutputId: newModelOutput.id, + modelResponseId: newModelResponse.id, }, }); }), diff --git a/src/server/utils/runOneEval.ts b/src/server/utils/runOneEval.ts index 2c2693d..b38abb6 100644 --- a/src/server/utils/runOneEval.ts +++ b/src/server/utils/runOneEval.ts @@ -1,4 +1,4 @@ -import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client"; +import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client"; import { type ChatCompletion } from "openai/resources/chat"; import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate"; import { openai } from "./openai"; @@ -70,9 +70,9 @@ export const runGpt4Eval = async ( export const runOneEval = async ( evaluation: Evaluation, scenario: TestScenario, - modelOutput: ModelOutput, + modelResponse: ModelResponse, ): Promise<{ result: number; details?: string }> => { - const output = modelOutput.output as unknown as ChatCompletion; + const output = modelResponse.output as unknown as ChatCompletion; const message = output?.choices?.[0]?.message;