diff --git a/prisma/migrations/20230725191512_migrate_model_response/migration.sql b/prisma/migrations/20230725191512_migrate_model_response/migration.sql
new file mode 100644
index 0000000..a0b3457
--- /dev/null
+++ b/prisma/migrations/20230725191512_migrate_model_response/migration.sql
@@ -0,0 +1,52 @@
+-- DropForeignKey
+ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey";
+
+-- DropForeignKey
+ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey";
+
+-- DropIndex
+DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key";
+
+-- AlterTable
+ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId";
+
+-- AlterTable
+ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime",
+DROP COLUMN "statusCode",
+ADD COLUMN "jobQueuedAt" TIMESTAMP(3),
+ADD COLUMN "jobStartedAt" TIMESTAMP(3);
+
+ALTER TABLE "ModelOutput" RENAME TO "ModelResponse";
+
+ALTER TABLE "ModelResponse"
+ADD COLUMN "requestedAt" TIMESTAMP(3),
+ADD COLUMN "receivedAt" TIMESTAMP(3),
+ADD COLUMN "statusCode" INTEGER,
+ADD COLUMN "errorMessage" TEXT,
+ADD COLUMN "retryTime" TIMESTAMP(3),
+ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false;
+
+-- 3. Remove the unnecessary column
+ALTER TABLE "ModelResponse"
+DROP COLUMN "timeToComplete";
+
+-- AlterTable
+ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey";
+ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL;
+
+-- DropIndex
+DROP INDEX "ModelOutput_scenarioVariantCellId_key";
+
+-- AddForeignKey
+ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- RenameIndex
+ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx";
+
+-- CreateIndex
+CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId");
+
+-- AddForeignKey
+ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index ca11195..11c6760 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -90,12 +90,11 @@ enum CellRetrievalStatus {
model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid
- statusCode Int?
- errorMessage String?
- retryTime DateTime?
retrievalStatus CellRetrievalStatus @default(COMPLETE)
-
- modelOutput ModelOutput?
+ jobQueuedAt DateTime?
+ jobStartedAt DateTime?
+ modelResponses ModelResponse[]
+ errorMessage String? // Contains errors that occurred independently of model responses
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -110,15 +109,20 @@ model ScenarioVariantCell {
@@unique([promptVariantId, testScenarioId])
}
-model ModelOutput {
+model ModelResponse {
id String @id @default(uuid()) @db.Uuid
- inputHash String
- output Json
- timeToComplete Int @default(0)
- cost Float?
- promptTokens Int?
- completionTokens Int?
+ inputHash String
+ requestedAt DateTime?
+ receivedAt DateTime?
+ output Json?
+ cost Float?
+ promptTokens Int?
+ completionTokens Int?
+ statusCode Int?
+ errorMessage String?
+ retryTime DateTime?
+ outdated Boolean @default(false)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@ -127,7 +131,6 @@ model ModelOutput {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[]
- @@unique([scenarioVariantCellId])
@@index([inputHash])
}
@@ -159,8 +162,8 @@ model OutputEvaluation {
result Float
details String?
- modelOutputId String @db.Uuid
- modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
+ modelResponseId String @db.Uuid
+ modelResponse ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
@@ -168,7 +171,7 @@ model OutputEvaluation {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
- @@unique([modelOutputId, evaluationId])
+ @@unique([modelResponseId, evaluationId])
}
model Organization {
diff --git a/src/components/OutputsTable/OutputCell/CellContent.tsx b/src/components/OutputsTable/OutputCell/CellContent.tsx
new file mode 100644
index 0000000..14b3572
--- /dev/null
+++ b/src/components/OutputsTable/OutputCell/CellContent.tsx
@@ -0,0 +1,17 @@
+import { type StackProps, VStack } from "@chakra-ui/react";
+import { CellOptions } from "./CellOptions";
+
+export const CellContent = ({
+ hardRefetch,
+ hardRefetching,
+ children,
+ ...props
+}: {
+ hardRefetch: () => void;
+ hardRefetching: boolean;
+} & StackProps) => (
+
+
+ {children}
+
+);
diff --git a/src/components/OutputsTable/OutputCell/CellOptions.tsx b/src/components/OutputsTable/OutputCell/CellOptions.tsx
index 00c7836..dfdd139 100644
--- a/src/components/OutputsTable/OutputCell/CellOptions.tsx
+++ b/src/components/OutputsTable/OutputCell/CellOptions.tsx
@@ -1,4 +1,4 @@
-import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
+import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";
@@ -12,7 +12,7 @@ export const CellOptions = ({
const { canModify } = useExperimentAccess();
return (
- {!refetchingOutput && canModify && (
+ {canModify && (
)}
diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
index e8450ef..dd0263d 100644
--- a/src/components/OutputsTable/OutputCell/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -1,16 +1,19 @@
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
+import { Text, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect } from "react";
+import { type ReactElement, useState, useEffect, Fragment } from "react";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
-import { ErrorHandler } from "./ErrorHandler";
-import { CellOptions } from "./CellOptions";
+import { RetryCountdown } from "./RetryCountdown";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
+import { ResponseLog } from "./ResponseLog";
+import { CellContent } from "./CellContent";
+
+const WAITING_MESSAGE_INTERVAL = 20000;
export default function OutputCell({
scenario,
@@ -65,46 +68,91 @@ export default function OutputCell({
hardRefetching;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
- const modelOutput = cell?.modelOutput;
-
// TODO: disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(cell?.id);
if (!vars) return null;
- if (disabledReason) return {disabledReason};
-
- if (awaitingOutput && !streamedMessage)
- return (
-
-
-
- );
-
if (!cell && !fetchingOutput)
return (
-
-
+
Error retrieving output
-
+
);
if (cell && cell.errorMessage) {
return (
-
-
-
-
+
+ {cell.errorMessage}
+
);
}
- const normalizedOutput = modelOutput
- ? provider.normalizeOutput(modelOutput.output)
+ if (disabledReason) return {disabledReason};
+
+ const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
+ const showLogs = !streamedMessage && !mostRecentResponse?.output;
+
+ if (showLogs)
+ return (
+
+ {cell?.jobQueuedAt && }
+ {cell?.jobStartedAt && }
+ {cell?.modelResponses?.map((response) => {
+ let numWaitingMessages = 0;
+ const relativeWaitingTime = response.receivedAt
+ ? response.receivedAt.getTime()
+ : Date.now();
+ if (response.requestedAt) {
+ numWaitingMessages = Math.floor(
+ (relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
+ );
+ }
+ return (
+
+ {response.requestedAt && (
+
+ )}
+ {response.requestedAt &&
+ Array.from({ length: numWaitingMessages }, (_, i) => (
+
+ ))}
+ {response.receivedAt && (
+
+ )}
+
+ );
+ }) ?? null}
+ {mostRecentResponse?.retryTime && (
+
+ )}
+
+ );
+
+ const normalizedOutput = mostRecentResponse?.output
+ ? provider.normalizeOutput(mostRecentResponse?.output)
: streamedMessage
? provider.normalizeOutput(streamedMessage)
: null;
- if (modelOutput && normalizedOutput?.type === "json") {
+ if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
return (
-
-
+
{stringify(normalizedOutput.value, { maxLength: 40 })}
-
-
+
+
);
}
@@ -138,10 +191,13 @@ export default function OutputCell({
return (
-
- {contentToDisplay}
+
+ {contentToDisplay}
+
- {modelOutput && }
+ {mostRecentResponse?.output && (
+
+ )}
);
}
diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx
index 9a77a3a..2797463 100644
--- a/src/components/OutputsTable/OutputCell/OutputStats.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx
@@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_TIME = true;
export const OutputStats = ({
- modelOutput,
+ modelResponse,
}: {
- modelOutput: NonNullable<
- NonNullable["modelOutput"]
+ modelResponse: NonNullable<
+ NonNullable["modelResponses"][0]
>;
scenario: Scenario;
}) => {
- const timeToComplete = modelOutput.timeToComplete;
+ const timeToComplete =
+ modelResponse.receivedAt && modelResponse.requestedAt
+ ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
+ : 0;
- const promptTokens = modelOutput.promptTokens;
- const completionTokens = modelOutput.completionTokens;
+ const promptTokens = modelResponse.promptTokens;
+ const completionTokens = modelResponse.completionTokens;
return (
- {modelOutput.outputEvaluations.map((evaluation) => {
+ {modelResponse.outputEvaluations.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
{evaluation.evaluation.label}
@@ -42,15 +46,15 @@ export const OutputStats = ({
);
})}
- {modelOutput.cost && (
+ {modelResponse.cost && (
- {modelOutput.cost.toFixed(3)}
+ {modelResponse.cost.toFixed(3)}
)}
diff --git a/src/components/OutputsTable/OutputCell/ResponseLog.tsx b/src/components/OutputsTable/OutputCell/ResponseLog.tsx
new file mode 100644
index 0000000..6672321
--- /dev/null
+++ b/src/components/OutputsTable/OutputCell/ResponseLog.tsx
@@ -0,0 +1,22 @@
+import { HStack, VStack, Text } from "@chakra-ui/react";
+import dayjs from "dayjs";
+
+export const ResponseLog = ({
+ time,
+ title,
+ message,
+}: {
+ time: Date;
+ title: string;
+ message?: string;
+}) => {
+ return (
+
+
+ {dayjs(time).format("HH:mm:ss")}
+ {title}
+
+ {message && {message}}
+
+ );
+};
diff --git a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx b/src/components/OutputsTable/OutputCell/RetryCountdown.tsx
similarity index 50%
rename from src/components/OutputsTable/OutputCell/ErrorHandler.tsx
rename to src/components/OutputsTable/OutputCell/RetryCountdown.tsx
index b626bcc..d836da3 100644
--- a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx
+++ b/src/components/OutputsTable/OutputCell/RetryCountdown.tsx
@@ -1,21 +1,12 @@
-import { type ScenarioVariantCell } from "@prisma/client";
-import { VStack, Text } from "@chakra-ui/react";
+import { Text } from "@chakra-ui/react";
import { useEffect, useState } from "react";
import pluralize from "pluralize";
-export const ErrorHandler = ({
- cell,
- refetchOutput,
-}: {
- cell: ScenarioVariantCell;
- refetchOutput: () => void;
-}) => {
+export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
const [msToWait, setMsToWait] = useState(0);
useEffect(() => {
- if (!cell.retryTime) return;
-
- const initialWaitTime = cell.retryTime.getTime() - Date.now();
+ const initialWaitTime = retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime);
@@ -36,18 +27,13 @@ export const ErrorHandler = ({
clearInterval(interval);
clearTimeout(timeout);
};
- }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
+ }, [retryTime]);
+
+ if (msToWait <= 0) return null;
return (
-
-
- {cell.errorMessage}
-
- {msToWait > 0 && (
-
- Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
-
- )}
-
+
+ Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
+
);
};
diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts
index c5f4bf4..26f87ff 100644
--- a/src/server/api/routers/experiments.router.ts
+++ b/src/server/api/routers/experiments.router.ts
@@ -118,7 +118,7 @@ export const experimentsRouter = createTRPCRouter({
},
},
include: {
- modelOutput: {
+ modelResponses: {
include: {
outputEvaluations: true,
},
@@ -177,11 +177,11 @@ export const experimentsRouter = createTRPCRouter({
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
- const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
+ const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
- const { modelOutput, ...cellData } = cell;
+ const { modelResponses, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
@@ -189,20 +189,20 @@ export const experimentsRouter = createTRPCRouter({
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
- if (modelOutput) {
- const newModelOutputId = uuidv4();
- const { outputEvaluations, ...modelOutputData } = modelOutput;
- modelOutputsToCreate.push({
- ...modelOutputData,
- id: newModelOutputId,
+ for (const modelResponse of modelResponses) {
+ const newModelResponseId = uuidv4();
+ const { outputEvaluations, ...modelResponseData } = modelResponse;
+ modelResponsesToCreate.push({
+ ...modelResponseData,
+ id: newModelResponseId,
scenarioVariantCellId: newCellId,
- output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
+ output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
...evaluation,
id: uuidv4(),
- modelOutputId: newModelOutputId,
+ modelResponseId: newModelResponseId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
});
}
@@ -245,8 +245,8 @@ export const experimentsRouter = createTRPCRouter({
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
- prisma.modelOutput.createMany({
- data: modelOutputsToCreate,
+ prisma.modelResponse.createMany({
+ data: modelResponsesToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,
diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts
index 8832dc4..92e6cd5 100644
--- a/src/server/api/routers/promptVariants.router.ts
+++ b/src/server/api/routers/promptVariants.router.ts
@@ -1,6 +1,7 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
+import { Prisma } from "@prisma/client";
import { generateNewCell } from "~/server/utils/generateNewCell";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
@@ -51,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({
id: true,
},
where: {
- modelOutput: {
+ modelResponse: {
+ outdated: false,
+ output: { not: Prisma.AnyNull },
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
@@ -93,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
- modelOutput: {
- is: {},
+ modelResponses: {
+ some: {
+ outdated: false,
+ output: {
+ not: Prisma.AnyNull,
+ },
+ },
},
},
});
- const overallTokens = await prisma.modelOutput.aggregate({
+ const overallTokens = await prisma.modelResponse.aggregate({
where: {
+ outdated: false,
+ output: {
+ not: Prisma.AnyNull,
+ },
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: {
diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts
index b0dbecb..8a9a60d 100644
--- a/src/server/api/routers/scenarioVariantCells.router.ts
+++ b/src/server/api/routers/scenarioVariantCells.router.ts
@@ -27,7 +27,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
},
},
include: {
- modelOutput: {
+ modelResponses: {
+ where: {
+ outdated: false,
+ },
include: {
outputEvaluations: {
include: {
@@ -62,7 +65,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
testScenarioId: input.scenarioId,
},
},
- include: { modelOutput: true },
});
if (!cell) {
@@ -70,12 +72,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
return;
}
- if (cell.modelOutput) {
- // TODO: Maybe keep these around to show previous generations?
- await prisma.modelOutput.delete({
- where: { id: cell.modelOutput.id },
- });
- }
+ await prisma.modelResponse.updateMany({
+ where: { scenarioVariantCellId: cell.id },
+ data: {
+ outdated: true,
+ },
+ });
await queueQueryModel(cell.id, true);
}),
diff --git a/src/server/tasks/queryModel.task.ts b/src/server/tasks/queryModel.task.ts
index 64b4f61..de42388 100644
--- a/src/server/tasks/queryModel.task.ts
+++ b/src/server/tasks/queryModel.task.ts
@@ -29,17 +29,9 @@ export const queryModel = defineTask("queryModel", async (task) =
const { cellId, stream } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: cellId },
- include: { modelOutput: true },
+ include: { modelResponses: true },
});
if (!cell) {
- await prisma.scenarioVariantCell.update({
- where: { id: cellId },
- data: {
- statusCode: 404,
- errorMessage: "Cell not found",
- retrievalStatus: "ERROR",
- },
- });
return;
}
@@ -51,6 +43,7 @@ export const queryModel = defineTask("queryModel", async (task) =
where: { id: cellId },
data: {
retrievalStatus: "IN_PROGRESS",
+ jobStartedAt: new Date(),
},
});
@@ -61,7 +54,6 @@ export const queryModel = defineTask("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
- statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
@@ -76,7 +68,6 @@ export const queryModel = defineTask("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
- statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
@@ -90,7 +81,6 @@ export const queryModel = defineTask("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
- statusCode: 400,
errorMessage: prompt.error,
retrievalStatus: "ERROR",
},
@@ -106,17 +96,24 @@ export const queryModel = defineTask("queryModel", async (task) =
}
: null;
+ const inputHash = hashPrompt(prompt);
+
for (let i = 0; true; i++) {
+ const modelResponse = await prisma.modelResponse.create({
+ data: {
+ inputHash,
+ scenarioVariantCellId: cellId,
+ requestedAt: new Date(),
+ },
+ });
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
- const inputHash = hashPrompt(prompt);
-
- const modelOutput = await prisma.modelOutput.create({
+ await prisma.modelResponse.update({
+ where: { id: modelResponse.id },
data: {
- scenarioVariantCellId: cellId,
- inputHash,
output: response.value as Prisma.InputJsonObject,
- timeToComplete: response.timeToComplete,
+ statusCode: response.statusCode,
+ receivedAt: new Date(),
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
cost: response.cost,
@@ -126,30 +123,35 @@ export const queryModel = defineTask("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
- statusCode: response.statusCode,
retrievalStatus: "COMPLETE",
},
});
- await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
+ await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
break;
} else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i);
- await prisma.scenarioVariantCell.update({
- where: { id: cellId },
+ await prisma.modelResponse.update({
+ where: { id: modelResponse.id },
data: {
- errorMessage: response.message,
statusCode: response.statusCode,
+ errorMessage: response.message,
+ receivedAt: new Date(),
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
- retrievalStatus: "ERROR",
},
});
if (shouldRetry) {
await sleep(delay);
} else {
+ await prisma.scenarioVariantCell.update({
+ where: { id: cellId },
+ data: {
+ retrievalStatus: "ERROR",
+ },
+ });
break;
}
}
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
data: {
retrievalStatus: "PENDING",
errorMessage: null,
+ jobQueuedAt: new Date(),
},
}),
queryModel.enqueue({ cellId, stream }),
diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts
index 1bcadab..530598c 100644
--- a/src/server/utils/evaluations.ts
+++ b/src/server/utils/evaluations.ts
@@ -1,19 +1,23 @@
-import { type ModelOutput, type Evaluation } from "@prisma/client";
+import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
import { prisma } from "../db";
import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
-const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
- const result = await runOneEval(evaluation, scenario, modelOutput);
+const saveResult = async (
+ evaluation: Evaluation,
+ scenario: Scenario,
+ modelResponse: ModelResponse,
+) => {
+ const result = await runOneEval(evaluation, scenario, modelResponse);
return await prisma.outputEvaluation.upsert({
where: {
- modelOutputId_evaluationId: {
- modelOutputId: modelOutput.id,
+ modelResponseId_evaluationId: {
+ modelResponseId: modelResponse.id,
evaluationId: evaluation.id,
},
},
create: {
- modelOutputId: modelOutput.id,
+ modelResponseId: modelResponse.id,
evaluationId: evaluation.id,
...result,
},
@@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
export const runEvalsForOutput = async (
experimentId: string,
scenario: Scenario,
- modelOutput: ModelOutput,
+ modelResponse: ModelResponse,
) => {
const evaluations = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
- evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
+ evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
);
};
export const runAllEvals = async (experimentId: string) => {
- const outputs = await prisma.modelOutput.findMany({
+ const outputs = await prisma.modelResponse.findMany({
where: {
+ outdated: false,
+ output: {
+ not: Prisma.AnyNull,
+ },
scenarioVariantCell: {
promptVariant: {
experimentId,
diff --git a/src/server/utils/generateNewCell.ts b/src/server/utils/generateNewCell.ts
index 0c4c2be..dcc9977 100644
--- a/src/server/utils/generateNewCell.ts
+++ b/src/server/utils/generateNewCell.ts
@@ -1,4 +1,4 @@
-import { type Prisma } from "@prisma/client";
+import { Prisma } from "@prisma/client";
import { prisma } from "../db";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
@@ -35,7 +35,7 @@ export const generateNewCell = async (
},
},
include: {
- modelOutput: true,
+ modelResponses: true,
},
});
@@ -51,8 +51,6 @@ export const generateNewCell = async (
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
- statusCode: 400,
- errorMessage: parsedConstructFn.error,
retrievalStatus: "ERROR",
},
});
@@ -69,36 +67,55 @@ export const generateNewCell = async (
retrievalStatus: "PENDING",
},
include: {
- modelOutput: true,
+ modelResponses: true,
},
});
- const matchingModelOutput = await prisma.modelOutput.findFirst({
- where: { inputHash },
+ const matchingModelResponse = await prisma.modelResponse.findFirst({
+ where: {
+ inputHash,
+ output: {
+ not: Prisma.AnyNull,
+ },
+ },
+ orderBy: {
+ receivedAt: "desc",
+ },
+ include: {
+ scenarioVariantCell: true,
+ },
+ take: 1,
});
- if (matchingModelOutput) {
- const newModelOutput = await prisma.modelOutput.create({
+ if (matchingModelResponse) {
+ const newModelResponse = await prisma.modelResponse.create({
data: {
- ...omit(matchingModelOutput, ["id"]),
+ ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id,
- output: matchingModelOutput.output as Prisma.InputJsonValue,
+ output: matchingModelResponse.output as Prisma.InputJsonValue,
},
});
+
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
- data: { retrievalStatus: "COMPLETE" },
+ data: {
+ retrievalStatus: "COMPLETE",
+ jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
+ jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
+ },
});
// Copy over all eval results as well
await Promise.all(
(
- await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
+ await prisma.outputEvaluation.findMany({
+ where: { modelResponseId: matchingModelResponse.id },
+ })
).map(async (evaluation) => {
await prisma.outputEvaluation.create({
data: {
...omit(evaluation, ["id"]),
- modelOutputId: newModelOutput.id,
+ modelResponseId: newModelResponse.id,
},
});
}),
diff --git a/src/server/utils/runOneEval.ts b/src/server/utils/runOneEval.ts
index 2c2693d..b38abb6 100644
--- a/src/server/utils/runOneEval.ts
+++ b/src/server/utils/runOneEval.ts
@@ -1,4 +1,4 @@
-import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
+import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
@@ -70,9 +70,9 @@ export const runGpt4Eval = async (
export const runOneEval = async (
evaluation: Evaluation,
scenario: TestScenario,
- modelOutput: ModelOutput,
+ modelResponse: ModelResponse,
): Promise<{ result: number; details?: string }> => {
- const output = modelOutput.output as unknown as ChatCompletion;
+ const output = modelResponse.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;