Store multiple ModelResponses (#95)

* Store multiple ModelResponses

* Fix prettier

* Add CellContent container
This commit is contained in:
arcticfly
2023-07-25 18:54:38 -07:00
committed by GitHub
parent 45afb1f1f4
commit 98b231c8bd
15 changed files with 341 additions and 159 deletions

View File

@@ -0,0 +1,52 @@
-- DropForeignKey
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey";
-- DropForeignKey
ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey";
-- DropIndex
DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key";
-- AlterTable
ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId";
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime",
DROP COLUMN "statusCode",
ADD COLUMN "jobQueuedAt" TIMESTAMP(3),
ADD COLUMN "jobStartedAt" TIMESTAMP(3);
ALTER TABLE "ModelOutput" RENAME TO "ModelResponse";
ALTER TABLE "ModelResponse"
ADD COLUMN "requestedAt" TIMESTAMP(3),
ADD COLUMN "receivedAt" TIMESTAMP(3),
ADD COLUMN "statusCode" INTEGER,
ADD COLUMN "errorMessage" TEXT,
ADD COLUMN "retryTime" TIMESTAMP(3),
ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false;
-- 3. Remove the unnecessary column
ALTER TABLE "ModelResponse"
DROP COLUMN "timeToComplete";
-- AlterTable
ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey";
ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL;
-- DropIndex
DROP INDEX "ModelOutput_scenarioVariantCellId_key";
-- AddForeignKey
ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- RenameIndex
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx";
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -90,12 +90,11 @@ enum CellRetrievalStatus {
model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid
statusCode Int?
errorMessage String?
retryTime DateTime?
retrievalStatus CellRetrievalStatus @default(COMPLETE)
modelOutput ModelOutput?
jobQueuedAt DateTime?
jobStartedAt DateTime?
modelResponses ModelResponse[]
errorMessage String? // Contains errors that occurred independently of model responses
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -110,15 +109,20 @@ model ScenarioVariantCell {
@@unique([promptVariantId, testScenarioId])
}
model ModelOutput {
model ModelResponse {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
timeToComplete Int @default(0)
requestedAt DateTime?
receivedAt DateTime?
output Json?
cost Float?
promptTokens Int?
completionTokens Int?
statusCode Int?
errorMessage String?
retryTime DateTime?
outdated Boolean @default(false)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@ -127,7 +131,6 @@ model ModelOutput {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash])
}
@@ -159,8 +162,8 @@ model OutputEvaluation {
result Float
details String?
modelOutputId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
modelResponseId String @db.Uuid
modelResponse ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
@@ -168,7 +171,7 @@ model OutputEvaluation {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([modelOutputId, evaluationId])
@@unique([modelResponseId, evaluationId])
}
model Organization {

View File

@@ -0,0 +1,17 @@
import { type StackProps, VStack } from "@chakra-ui/react";
import { CellOptions } from "./CellOptions";
// Scrollable container for a table cell's body. Renders the shared
// refetch controls (CellOptions) above whatever content is passed in,
// and forwards any extra VStack props to the container.
type CellContentProps = {
  hardRefetch: () => void;
  hardRefetching: boolean;
} & StackProps;

export const CellContent = (props: CellContentProps) => {
  const { hardRefetch, hardRefetching, children, ...stackProps } = props;
  return (
    <VStack maxH={500} w="full" overflowY="auto" alignItems="flex-start" {...stackProps}>
      <CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
      {children}
    </VStack>
  );
};

View File

@@ -1,4 +1,4 @@
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";
@@ -12,7 +12,7 @@ export const CellOptions = ({
const { canModify } = useExperimentAccess();
return (
<HStack justifyContent="flex-end" w="full">
{!refetchingOutput && canModify && (
{canModify && (
<Tooltip label="Refetch output" aria-label="refetch output">
<Button
size="xs"
@@ -28,7 +28,7 @@ export const CellOptions = ({
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={4} />
<Icon as={refetchingOutput ? Spinner : BsArrowClockwise} boxSize={4} />
</Button>
</Tooltip>
)}

View File

@@ -1,16 +1,19 @@
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
import { Text, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect } from "react";
import { type ReactElement, useState, useEffect, Fragment } from "react";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
import { RetryCountdown } from "./RetryCountdown";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { ResponseLog } from "./ResponseLog";
import { CellContent } from "./CellContent";
const WAITING_MESSAGE_INTERVAL = 20000;
export default function OutputCell({
scenario,
@@ -65,46 +68,91 @@ export default function OutputCell({
hardRefetching;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
// TODO: disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket<OutputSchema>(cell?.id);
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (awaitingOutput && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!cell && !fetchingOutput)
return (
<VStack>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<Text color="gray.500">Error retrieving output</Text>
</VStack>
</CellContent>
);
if (cell && cell.errorMessage) {
return (
<VStack>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
<ErrorHandler cell={cell} refetchOutput={hardRefetch} />
</VStack>
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<Text color="red.500">{cell.errorMessage}</Text>
</CellContent>
);
}
const normalizedOutput = modelOutput
? provider.normalizeOutput(modelOutput.output)
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
const showLogs = !streamedMessage && !mostRecentResponse?.output;
if (showLogs)
return (
<CellContent
hardRefetching={hardRefetching}
hardRefetch={hardRefetch}
alignItems="flex-start"
fontFamily="inconsolata, monospace"
spacing={0}
>
{cell?.jobQueuedAt && <ResponseLog time={cell.jobQueuedAt} title="Job queued" />}
{cell?.jobStartedAt && <ResponseLog time={cell.jobStartedAt} title="Job started" />}
{cell?.modelResponses?.map((response) => {
let numWaitingMessages = 0;
const relativeWaitingTime = response.receivedAt
? response.receivedAt.getTime()
: Date.now();
if (response.requestedAt) {
numWaitingMessages = Math.floor(
(relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
);
}
return (
<Fragment key={response.id}>
{response.requestedAt && (
<ResponseLog time={response.requestedAt} title="Request sent to API" />
)}
{response.requestedAt &&
Array.from({ length: numWaitingMessages }, (_, i) => (
<ResponseLog
key={`waiting-${i}`}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
time={new Date(response.requestedAt!.getTime() + i * WAITING_MESSAGE_INTERVAL)}
title="Waiting for response"
/>
))}
{response.receivedAt && (
<ResponseLog
time={response.receivedAt}
title="Response received from API"
message={`statusCode: ${response.statusCode ?? ""}\n ${
response.errorMessage ?? ""
}`}
/>
)}
</Fragment>
);
}) ?? null}
{mostRecentResponse?.retryTime && (
<RetryCountdown retryTime={mostRecentResponse.retryTime} />
)}
</CellContent>
);
const normalizedOutput = mostRecentResponse?.output
? provider.normalizeOutput(mostRecentResponse?.output)
: streamedMessage
? provider.normalizeOutput(streamedMessage)
: null;
if (modelOutput && normalizedOutput?.type === "json") {
if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
return (
<VStack
w="100%"
@@ -114,8 +162,13 @@ export default function OutputCell({
overflowX="hidden"
justifyContent="space-between"
>
<VStack w="full" flex={1} spacing={0}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
<CellContent
hardRefetching={hardRefetching}
hardRefetch={hardRefetch}
w="full"
flex={1}
spacing={0}
>
<SyntaxHighlighter
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
language="json"
@@ -127,8 +180,8 @@ export default function OutputCell({
>
{stringify(normalizedOutput.value, { maxLength: 40 })}
</SyntaxHighlighter>
</VStack>
<OutputStats modelOutput={modelOutput} scenario={scenario} />
</CellContent>
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
</VStack>
);
}
@@ -138,10 +191,13 @@ export default function OutputCell({
return (
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
<VStack w="full" alignItems="flex-start" spacing={0}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<Text>{contentToDisplay}</Text>
</CellContent>
</VStack>
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
{mostRecentResponse?.output && (
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
)}
</VStack>
);
}

View File

@@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_TIME = true;
export const OutputStats = ({
modelOutput,
modelResponse,
}: {
modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
modelResponse: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelResponses"][0]
>;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete;
const timeToComplete =
modelResponse.receivedAt && modelResponse.requestedAt
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
: 0;
const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens;
const promptTokens = modelResponse.promptTokens;
const completionTokens = modelResponse.completionTokens;
return (
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}>
{modelOutput.outputEvaluations.map((evaluation) => {
{modelResponse.outputEvaluations.map((evaluation) => {
const passed = evaluation.result > 0.5;
return (
<Tooltip
isDisabled={!evaluation.details}
label={evaluation.details}
key={evaluation.id}
shouldWrapChildren
>
<HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text>
@@ -42,15 +46,15 @@ export const OutputStats = ({
);
})}
</HStack>
{modelOutput.cost && (
{modelResponse.cost && (
<CostTooltip
promptTokens={promptTokens}
completionTokens={completionTokens}
cost={modelOutput.cost}
cost={modelResponse.cost}
>
<HStack spacing={0}>
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
<Text mr={1}>{modelResponse.cost.toFixed(3)}</Text>
</HStack>
</CostTooltip>
)}

View File

@@ -0,0 +1,22 @@
import { HStack, VStack, Text } from "@chakra-ui/react";
import dayjs from "dayjs";
export const ResponseLog = ({
time,
title,
message,
}: {
time: Date;
title: string;
message?: string;
}) => {
return (
<VStack spacing={0} alignItems="flex-start">
<HStack>
<Text>{dayjs(time).format("HH:mm:ss")}</Text>
<Text>{title}</Text>
</HStack>
{message && <Text pl={4}>{message}</Text>}
</VStack>
);
};

View File

@@ -1,21 +1,12 @@
import { type ScenarioVariantCell } from "@prisma/client";
import { VStack, Text } from "@chakra-ui/react";
import { Text } from "@chakra-ui/react";
import { useEffect, useState } from "react";
import pluralize from "pluralize";
export const ErrorHandler = ({
cell,
refetchOutput,
}: {
cell: ScenarioVariantCell;
refetchOutput: () => void;
}) => {
export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
const [msToWait, setMsToWait] = useState(0);
useEffect(() => {
if (!cell.retryTime) return;
const initialWaitTime = cell.retryTime.getTime() - Date.now();
const initialWaitTime = retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime);
@@ -36,18 +27,13 @@ export const ErrorHandler = ({
clearInterval(interval);
clearTimeout(timeout);
};
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
}, [retryTime]);
if (msToWait <= 0) return null;
return (
<VStack w="full">
<Text color="red.600" wordBreak="break-word">
{cell.errorMessage}
</Text>
{msToWait > 0 && (
<Text color="red.600" fontSize="sm">
Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
</Text>
)}
</VStack>
);
};

View File

@@ -118,7 +118,7 @@ export const experimentsRouter = createTRPCRouter({
},
},
include: {
modelOutput: {
modelResponses: {
include: {
outputEvaluations: true,
},
@@ -177,11 +177,11 @@ export const experimentsRouter = createTRPCRouter({
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
const { modelOutput, ...cellData } = cell;
const { modelResponses, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
@@ -189,20 +189,20 @@ export const experimentsRouter = createTRPCRouter({
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
if (modelOutput) {
const newModelOutputId = uuidv4();
const { outputEvaluations, ...modelOutputData } = modelOutput;
modelOutputsToCreate.push({
...modelOutputData,
id: newModelOutputId,
for (const modelResponse of modelResponses) {
const newModelResponseId = uuidv4();
const { outputEvaluations, ...modelResponseData } = modelResponse;
modelResponsesToCreate.push({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
...evaluation,
id: uuidv4(),
modelOutputId: newModelOutputId,
modelResponseId: newModelResponseId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
});
}
@@ -245,8 +245,8 @@ export const experimentsRouter = createTRPCRouter({
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
prisma.modelOutput.createMany({
data: modelOutputsToCreate,
prisma.modelResponse.createMany({
data: modelResponsesToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,

View File

@@ -1,6 +1,7 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { Prisma } from "@prisma/client";
import { generateNewCell } from "~/server/utils/generateNewCell";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
@@ -51,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({
id: true,
},
where: {
modelOutput: {
modelResponse: {
outdated: false,
output: { not: Prisma.AnyNull },
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
@@ -93,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
modelResponses: {
some: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
},
},
},
});
const overallTokens = await prisma.modelOutput.aggregate({
const overallTokens = await prisma.modelResponse.aggregate({
where: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: {

View File

@@ -27,7 +27,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
},
},
include: {
modelOutput: {
modelResponses: {
where: {
outdated: false,
},
include: {
outputEvaluations: {
include: {
@@ -62,7 +65,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
testScenarioId: input.scenarioId,
},
},
include: { modelOutput: true },
});
if (!cell) {
@@ -70,12 +72,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
return;
}
if (cell.modelOutput) {
// TODO: Maybe keep these around to show previous generations?
await prisma.modelOutput.delete({
where: { id: cell.modelOutput.id },
await prisma.modelResponse.updateMany({
where: { scenarioVariantCellId: cell.id },
data: {
outdated: true,
},
});
}
await queueQueryModel(cell.id, true);
}),

View File

@@ -29,17 +29,9 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
const { cellId, stream } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: cellId },
include: { modelOutput: true },
include: { modelResponses: true },
});
if (!cell) {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Cell not found",
retrievalStatus: "ERROR",
},
});
return;
}
@@ -51,6 +43,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
where: { id: cellId },
data: {
retrievalStatus: "IN_PROGRESS",
jobStartedAt: new Date(),
},
});
@@ -61,7 +54,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
@@ -76,7 +68,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
@@ -90,7 +81,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 400,
errorMessage: prompt.error,
retrievalStatus: "ERROR",
},
@@ -106,17 +96,24 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}
: null;
for (let i = 0; true; i++) {
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt);
const modelOutput = await prisma.modelOutput.create({
for (let i = 0; true; i++) {
const modelResponse = await prisma.modelResponse.create({
data: {
scenarioVariantCellId: cellId,
inputHash,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete,
statusCode: response.statusCode,
receivedAt: new Date(),
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
cost: response.cost,
@@ -126,30 +123,35 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: response.statusCode,
retrievalStatus: "COMPLETE",
},
});
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
break;
} else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
errorMessage: response.message,
statusCode: response.statusCode,
errorMessage: response.message,
receivedAt: new Date(),
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: "ERROR",
},
});
if (shouldRetry) {
await sleep(delay);
} else {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "ERROR",
},
});
break;
}
}
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
data: {
retrievalStatus: "PENDING",
errorMessage: null,
jobQueuedAt: new Date(),
},
}),
queryModel.enqueue({ cellId, stream }),

View File

@@ -1,19 +1,23 @@
import { type ModelOutput, type Evaluation } from "@prisma/client";
import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
import { prisma } from "../db";
import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const result = await runOneEval(evaluation, scenario, modelOutput);
const saveResult = async (
evaluation: Evaluation,
scenario: Scenario,
modelResponse: ModelResponse,
) => {
const result = await runOneEval(evaluation, scenario, modelResponse);
return await prisma.outputEvaluation.upsert({
where: {
modelOutputId_evaluationId: {
modelOutputId: modelOutput.id,
modelResponseId_evaluationId: {
modelResponseId: modelResponse.id,
evaluationId: evaluation.id,
},
},
create: {
modelOutputId: modelOutput.id,
modelResponseId: modelResponse.id,
evaluationId: evaluation.id,
...result,
},
@@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
export const runEvalsForOutput = async (
experimentId: string,
scenario: Scenario,
modelOutput: ModelOutput,
modelResponse: ModelResponse,
) => {
const evaluations = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
);
};
export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({
const outputs = await prisma.modelResponse.findMany({
where: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
promptVariant: {
experimentId,

View File

@@ -1,4 +1,4 @@
import { type Prisma } from "@prisma/client";
import { Prisma } from "@prisma/client";
import { prisma } from "../db";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
@@ -35,7 +35,7 @@ export const generateNewCell = async (
},
},
include: {
modelOutput: true,
modelResponses: true,
},
});
@@ -51,8 +51,6 @@ export const generateNewCell = async (
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
statusCode: 400,
errorMessage: parsedConstructFn.error,
retrievalStatus: "ERROR",
},
});
@@ -69,36 +67,55 @@ export const generateNewCell = async (
retrievalStatus: "PENDING",
},
include: {
modelOutput: true,
modelResponses: true,
},
});
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: { inputHash },
const matchingModelResponse = await prisma.modelResponse.findFirst({
where: {
inputHash,
output: {
not: Prisma.AnyNull,
},
},
orderBy: {
receivedAt: "desc",
},
include: {
scenarioVariantCell: true,
},
take: 1,
});
if (matchingModelOutput) {
const newModelOutput = await prisma.modelOutput.create({
if (matchingModelResponse) {
const newModelResponse = await prisma.modelResponse.create({
data: {
...omit(matchingModelOutput, ["id"]),
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id,
output: matchingModelOutput.output as Prisma.InputJsonValue,
output: matchingModelResponse.output as Prisma.InputJsonValue,
},
});
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" },
data: {
retrievalStatus: "COMPLETE",
jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
},
});
// Copy over all eval results as well
await Promise.all(
(
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
await prisma.outputEvaluation.findMany({
where: { modelResponseId: matchingModelResponse.id },
})
).map(async (evaluation) => {
await prisma.outputEvaluation.create({
data: {
...omit(evaluation, ["id"]),
modelOutputId: newModelOutput.id,
modelResponseId: newModelResponse.id,
},
});
}),

View File

@@ -1,4 +1,4 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai";
@@ -70,9 +70,9 @@ export const runGpt4Eval = async (
export const runOneEval = async (
evaluation: Evaluation,
scenario: TestScenario,
modelOutput: ModelOutput,
modelResponse: ModelResponse,
): Promise<{ result: number; details?: string }> => {
const output = modelOutput.output as unknown as ChatCompletion;
const output = modelResponse.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message;