Store multiple ModelResponses (#95)

* Store multiple ModelResponses

* Fix prettier

* Add CellContent container
This commit is contained in:
arcticfly
2023-07-25 18:54:38 -07:00
committed by GitHub
parent 45afb1f1f4
commit 98b231c8bd
15 changed files with 341 additions and 159 deletions

View File

@@ -0,0 +1,52 @@
-- DropForeignKey
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey";
-- DropForeignKey
ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey";
-- DropIndex
DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key";
-- AlterTable
ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId";
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime",
DROP COLUMN "statusCode",
ADD COLUMN "jobQueuedAt" TIMESTAMP(3),
ADD COLUMN "jobStartedAt" TIMESTAMP(3);
ALTER TABLE "ModelOutput" RENAME TO "ModelResponse";
ALTER TABLE "ModelResponse"
ADD COLUMN "requestedAt" TIMESTAMP(3),
ADD COLUMN "receivedAt" TIMESTAMP(3),
ADD COLUMN "statusCode" INTEGER,
ADD COLUMN "errorMessage" TEXT,
ADD COLUMN "retryTime" TIMESTAMP(3),
ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false;
-- 3. Remove the unnecessary column
ALTER TABLE "ModelResponse"
DROP COLUMN "timeToComplete";
-- AlterTable
ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey";
ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL;
-- DropIndex
DROP INDEX "ModelOutput_scenarioVariantCellId_key";
-- AddForeignKey
ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- RenameIndex
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx";
-- CreateIndex
CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId");
-- AddForeignKey
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -90,12 +90,11 @@ enum CellRetrievalStatus {
model ScenarioVariantCell { model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
statusCode Int?
errorMessage String?
retryTime DateTime?
retrievalStatus CellRetrievalStatus @default(COMPLETE) retrievalStatus CellRetrievalStatus @default(COMPLETE)
jobQueuedAt DateTime?
modelOutput ModelOutput? jobStartedAt DateTime?
modelResponses ModelResponse[]
errorMessage String? // Contains errors that occurred independently of model responses
promptVariantId String @db.Uuid promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -110,15 +109,20 @@ model ScenarioVariantCell {
@@unique([promptVariantId, testScenarioId]) @@unique([promptVariantId, testScenarioId])
} }
model ModelOutput { model ModelResponse {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
inputHash String inputHash String
output Json requestedAt DateTime?
timeToComplete Int @default(0) receivedAt DateTime?
output Json?
cost Float? cost Float?
promptTokens Int? promptTokens Int?
completionTokens Int? completionTokens Int?
statusCode Int?
errorMessage String?
retryTime DateTime?
outdated Boolean @default(false)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@ -127,7 +131,6 @@ model ModelOutput {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[] outputEvaluations OutputEvaluation[]
@@unique([scenarioVariantCellId])
@@index([inputHash]) @@index([inputHash])
} }
@@ -159,8 +162,8 @@ model OutputEvaluation {
result Float result Float
details String? details String?
modelOutputId String @db.Uuid modelResponseId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade) modelResponse ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
@@ -168,7 +171,7 @@ model OutputEvaluation {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@unique([modelOutputId, evaluationId]) @@unique([modelResponseId, evaluationId])
} }
model Organization { model Organization {

View File

@@ -0,0 +1,17 @@
import { type StackProps, VStack } from "@chakra-ui/react";
import { CellOptions } from "./CellOptions";
export const CellContent = ({
hardRefetch,
hardRefetching,
children,
...props
}: {
hardRefetch: () => void;
hardRefetching: boolean;
} & StackProps) => (
<VStack maxH={500} w="full" overflowY="auto" alignItems="flex-start" {...props}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
{children}
</VStack>
);

View File

@@ -1,4 +1,4 @@
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs"; import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks"; import { useExperimentAccess } from "~/utils/hooks";
@@ -12,7 +12,7 @@ export const CellOptions = ({
const { canModify } = useExperimentAccess(); const { canModify } = useExperimentAccess();
return ( return (
<HStack justifyContent="flex-end" w="full"> <HStack justifyContent="flex-end" w="full">
{!refetchingOutput && canModify && ( {canModify && (
<Tooltip label="Refetch output" aria-label="refetch output"> <Tooltip label="Refetch output" aria-label="refetch output">
<Button <Button
size="xs" size="xs"
@@ -28,7 +28,7 @@ export const CellOptions = ({
onClick={refetchOutput} onClick={refetchOutput}
aria-label="refetch output" aria-label="refetch output"
> >
<Icon as={BsArrowClockwise} boxSize={4} /> <Icon as={refetchingOutput ? Spinner : BsArrowClockwise} boxSize={4} />
</Button> </Button>
</Tooltip> </Tooltip>
)} )}

View File

@@ -1,16 +1,19 @@
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types"; import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Center, VStack } from "@chakra-ui/react"; import { Text, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect } from "react"; import { type ReactElement, useState, useEffect, Fragment } from "react";
import useSocket from "~/utils/useSocket"; import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats"; import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler"; import { RetryCountdown } from "./RetryCountdown";
import { CellOptions } from "./CellOptions";
import frontendModelProviders from "~/modelProviders/frontendModelProviders"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { ResponseLog } from "./ResponseLog";
import { CellContent } from "./CellContent";
const WAITING_MESSAGE_INTERVAL = 20000;
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -65,46 +68,91 @@ export default function OutputCell({
hardRefetching; hardRefetching;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]); useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
// TODO: disconnect from socket if we're not streaming anymore // TODO: disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket<OutputSchema>(cell?.id); const streamedMessage = useSocket<OutputSchema>(cell?.id);
if (!vars) return null; if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (awaitingOutput && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!cell && !fetchingOutput) if (!cell && !fetchingOutput)
return ( return (
<VStack> <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
<Text color="gray.500">Error retrieving output</Text> <Text color="gray.500">Error retrieving output</Text>
</VStack> </CellContent>
); );
if (cell && cell.errorMessage) { if (cell && cell.errorMessage) {
return ( return (
<VStack> <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} /> <Text color="red.500">{cell.errorMessage}</Text>
<ErrorHandler cell={cell} refetchOutput={hardRefetch} /> </CellContent>
</VStack>
); );
} }
const normalizedOutput = modelOutput if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
? provider.normalizeOutput(modelOutput.output)
const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
const showLogs = !streamedMessage && !mostRecentResponse?.output;
if (showLogs)
return (
<CellContent
hardRefetching={hardRefetching}
hardRefetch={hardRefetch}
alignItems="flex-start"
fontFamily="inconsolata, monospace"
spacing={0}
>
{cell?.jobQueuedAt && <ResponseLog time={cell.jobQueuedAt} title="Job queued" />}
{cell?.jobStartedAt && <ResponseLog time={cell.jobStartedAt} title="Job started" />}
{cell?.modelResponses?.map((response) => {
let numWaitingMessages = 0;
const relativeWaitingTime = response.receivedAt
? response.receivedAt.getTime()
: Date.now();
if (response.requestedAt) {
numWaitingMessages = Math.floor(
(relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
);
}
return (
<Fragment key={response.id}>
{response.requestedAt && (
<ResponseLog time={response.requestedAt} title="Request sent to API" />
)}
{response.requestedAt &&
Array.from({ length: numWaitingMessages }, (_, i) => (
<ResponseLog
key={`waiting-${i}`}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
time={new Date(response.requestedAt!.getTime() + i * WAITING_MESSAGE_INTERVAL)}
title="Waiting for response"
/>
))}
{response.receivedAt && (
<ResponseLog
time={response.receivedAt}
title="Response received from API"
message={`statusCode: ${response.statusCode ?? ""}\n ${
response.errorMessage ?? ""
}`}
/>
)}
</Fragment>
);
}) ?? null}
{mostRecentResponse?.retryTime && (
<RetryCountdown retryTime={mostRecentResponse.retryTime} />
)}
</CellContent>
);
const normalizedOutput = mostRecentResponse?.output
? provider.normalizeOutput(mostRecentResponse?.output)
: streamedMessage : streamedMessage
? provider.normalizeOutput(streamedMessage) ? provider.normalizeOutput(streamedMessage)
: null; : null;
if (modelOutput && normalizedOutput?.type === "json") { if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
return ( return (
<VStack <VStack
w="100%" w="100%"
@@ -114,8 +162,13 @@ export default function OutputCell({
overflowX="hidden" overflowX="hidden"
justifyContent="space-between" justifyContent="space-between"
> >
<VStack w="full" flex={1} spacing={0}> <CellContent
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} /> hardRefetching={hardRefetching}
hardRefetch={hardRefetch}
w="full"
flex={1}
spacing={0}
>
<SyntaxHighlighter <SyntaxHighlighter
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }} customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
language="json" language="json"
@@ -127,8 +180,8 @@ export default function OutputCell({
> >
{stringify(normalizedOutput.value, { maxLength: 40 })} {stringify(normalizedOutput.value, { maxLength: 40 })}
</SyntaxHighlighter> </SyntaxHighlighter>
</VStack> </CellContent>
<OutputStats modelOutput={modelOutput} scenario={scenario} /> <OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
</VStack> </VStack>
); );
} }
@@ -138,10 +191,13 @@ export default function OutputCell({
return ( return (
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap"> <VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
<VStack w="full" alignItems="flex-start" spacing={0}> <VStack w="full" alignItems="flex-start" spacing={0}>
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} /> <CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
<Text>{contentToDisplay}</Text> <Text>{contentToDisplay}</Text>
</CellContent>
</VStack> </VStack>
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />} {mostRecentResponse?.output && (
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
)}
</VStack> </VStack>
); );
} }

View File

@@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_TIME = true; const SHOW_TIME = true;
export const OutputStats = ({ export const OutputStats = ({
modelOutput, modelResponse,
}: { }: {
modelOutput: NonNullable< modelResponse: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"] NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelResponses"][0]
>; >;
scenario: Scenario; scenario: Scenario;
}) => { }) => {
const timeToComplete = modelOutput.timeToComplete; const timeToComplete =
modelResponse.receivedAt && modelResponse.requestedAt
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
: 0;
const promptTokens = modelOutput.promptTokens; const promptTokens = modelResponse.promptTokens;
const completionTokens = modelOutput.completionTokens; const completionTokens = modelResponse.completionTokens;
return ( return (
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>
{modelOutput.outputEvaluations.map((evaluation) => { {modelResponse.outputEvaluations.map((evaluation) => {
const passed = evaluation.result > 0.5; const passed = evaluation.result > 0.5;
return ( return (
<Tooltip <Tooltip
isDisabled={!evaluation.details} isDisabled={!evaluation.details}
label={evaluation.details} label={evaluation.details}
key={evaluation.id} key={evaluation.id}
shouldWrapChildren
> >
<HStack spacing={0}> <HStack spacing={0}>
<Text>{evaluation.evaluation.label}</Text> <Text>{evaluation.evaluation.label}</Text>
@@ -42,15 +46,15 @@ export const OutputStats = ({
); );
})} })}
</HStack> </HStack>
{modelOutput.cost && ( {modelResponse.cost && (
<CostTooltip <CostTooltip
promptTokens={promptTokens} promptTokens={promptTokens}
completionTokens={completionTokens} completionTokens={completionTokens}
cost={modelOutput.cost} cost={modelResponse.cost}
> >
<HStack spacing={0}> <HStack spacing={0}>
<Icon as={BsCurrencyDollar} /> <Icon as={BsCurrencyDollar} />
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text> <Text mr={1}>{modelResponse.cost.toFixed(3)}</Text>
</HStack> </HStack>
</CostTooltip> </CostTooltip>
)} )}

View File

@@ -0,0 +1,22 @@
import { HStack, VStack, Text } from "@chakra-ui/react";
import dayjs from "dayjs";
export const ResponseLog = ({
time,
title,
message,
}: {
time: Date;
title: string;
message?: string;
}) => {
return (
<VStack spacing={0} alignItems="flex-start">
<HStack>
<Text>{dayjs(time).format("HH:mm:ss")}</Text>
<Text>{title}</Text>
</HStack>
{message && <Text pl={4}>{message}</Text>}
</VStack>
);
};

View File

@@ -1,21 +1,12 @@
import { type ScenarioVariantCell } from "@prisma/client"; import { Text } from "@chakra-ui/react";
import { VStack, Text } from "@chakra-ui/react";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import pluralize from "pluralize"; import pluralize from "pluralize";
export const ErrorHandler = ({ export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
cell,
refetchOutput,
}: {
cell: ScenarioVariantCell;
refetchOutput: () => void;
}) => {
const [msToWait, setMsToWait] = useState(0); const [msToWait, setMsToWait] = useState(0);
useEffect(() => { useEffect(() => {
if (!cell.retryTime) return; const initialWaitTime = retryTime.getTime() - Date.now();
const initialWaitTime = cell.retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000; const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond; let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime); setMsToWait(remainingTime);
@@ -36,18 +27,13 @@ export const ErrorHandler = ({
clearInterval(interval); clearInterval(interval);
clearTimeout(timeout); clearTimeout(timeout);
}; };
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]); }, [retryTime]);
if (msToWait <= 0) return null;
return ( return (
<VStack w="full">
<Text color="red.600" wordBreak="break-word">
{cell.errorMessage}
</Text>
{msToWait > 0 && (
<Text color="red.600" fontSize="sm"> <Text color="red.600" fontSize="sm">
Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}... Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
</Text> </Text>
)}
</VStack>
); );
}; };

View File

@@ -118,7 +118,7 @@ export const experimentsRouter = createTRPCRouter({
}, },
}, },
include: { include: {
modelOutput: { modelResponses: {
include: { include: {
outputEvaluations: true, outputEvaluations: true,
}, },
@@ -177,11 +177,11 @@ export const experimentsRouter = createTRPCRouter({
} }
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = []; const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = []; const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = []; const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) { for (const cell of existingCells) {
const newCellId = uuidv4(); const newCellId = uuidv4();
const { modelOutput, ...cellData } = cell; const { modelResponses, ...cellData } = cell;
cellsToCreate.push({ cellsToCreate.push({
...cellData, ...cellData,
id: newCellId, id: newCellId,
@@ -189,20 +189,20 @@ export const experimentsRouter = createTRPCRouter({
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "", testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined, prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
}); });
if (modelOutput) { for (const modelResponse of modelResponses) {
const newModelOutputId = uuidv4(); const newModelResponseId = uuidv4();
const { outputEvaluations, ...modelOutputData } = modelOutput; const { outputEvaluations, ...modelResponseData } = modelResponse;
modelOutputsToCreate.push({ modelResponsesToCreate.push({
...modelOutputData, ...modelResponseData,
id: newModelOutputId, id: newModelResponseId,
scenarioVariantCellId: newCellId, scenarioVariantCellId: newCellId,
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined, output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
}); });
for (const evaluation of outputEvaluations) { for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({ outputEvaluationsToCreate.push({
...evaluation, ...evaluation,
id: uuidv4(), id: uuidv4(),
modelOutputId: newModelOutputId, modelResponseId: newModelResponseId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "", evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
}); });
} }
@@ -245,8 +245,8 @@ export const experimentsRouter = createTRPCRouter({
prisma.scenarioVariantCell.createMany({ prisma.scenarioVariantCell.createMany({
data: cellsToCreate, data: cellsToCreate,
}), }),
prisma.modelOutput.createMany({ prisma.modelResponse.createMany({
data: modelOutputsToCreate, data: modelResponsesToCreate,
}), }),
prisma.evaluation.createMany({ prisma.evaluation.createMany({
data: evaluationsToCreate, data: evaluationsToCreate,

View File

@@ -1,6 +1,7 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { Prisma } from "@prisma/client";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import userError from "~/server/utils/error"; import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
@@ -51,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({
id: true, id: true,
}, },
where: { where: {
modelOutput: { modelResponse: {
outdated: false,
output: { not: Prisma.AnyNull },
scenarioVariantCell: { scenarioVariantCell: {
promptVariant: { promptVariant: {
id: input.variantId, id: input.variantId,
@@ -93,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({
where: { where: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { visible: true }, testScenario: { visible: true },
modelOutput: { modelResponses: {
is: {}, some: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
},
}, },
}, },
}); });
const overallTokens = await prisma.modelOutput.aggregate({ const overallTokens = await prisma.modelResponse.aggregate({
where: { where: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
scenarioVariantCell: { scenarioVariantCell: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { testScenario: {

View File

@@ -27,7 +27,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}, },
}, },
include: { include: {
modelOutput: { modelResponses: {
where: {
outdated: false,
},
include: { include: {
outputEvaluations: { outputEvaluations: {
include: { include: {
@@ -62,7 +65,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
testScenarioId: input.scenarioId, testScenarioId: input.scenarioId,
}, },
}, },
include: { modelOutput: true },
}); });
if (!cell) { if (!cell) {
@@ -70,12 +72,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
return; return;
} }
if (cell.modelOutput) { await prisma.modelResponse.updateMany({
// TODO: Maybe keep these around to show previous generations? where: { scenarioVariantCellId: cell.id },
await prisma.modelOutput.delete({ data: {
where: { id: cell.modelOutput.id }, outdated: true,
},
}); });
}
await queueQueryModel(cell.id, true); await queueQueryModel(cell.id, true);
}), }),

View File

@@ -29,17 +29,9 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
const { cellId, stream } = task; const { cellId, stream } = task;
const cell = await prisma.scenarioVariantCell.findUnique({ const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: cellId }, where: { id: cellId },
include: { modelOutput: true }, include: { modelResponses: true },
}); });
if (!cell) { if (!cell) {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Cell not found",
retrievalStatus: "ERROR",
},
});
return; return;
} }
@@ -51,6 +43,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
where: { id: cellId }, where: { id: cellId },
data: { data: {
retrievalStatus: "IN_PROGRESS", retrievalStatus: "IN_PROGRESS",
jobStartedAt: new Date(),
}, },
}); });
@@ -61,7 +54,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: cellId }, where: { id: cellId },
data: { data: {
statusCode: 404,
errorMessage: "Prompt Variant not found", errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR", retrievalStatus: "ERROR",
}, },
@@ -76,7 +68,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: cellId }, where: { id: cellId },
data: { data: {
statusCode: 404,
errorMessage: "Scenario not found", errorMessage: "Scenario not found",
retrievalStatus: "ERROR", retrievalStatus: "ERROR",
}, },
@@ -90,7 +81,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: cellId }, where: { id: cellId },
data: { data: {
statusCode: 400,
errorMessage: prompt.error, errorMessage: prompt.error,
retrievalStatus: "ERROR", retrievalStatus: "ERROR",
}, },
@@ -106,17 +96,24 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
} }
: null; : null;
for (let i = 0; true; i++) {
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt); const inputHash = hashPrompt(prompt);
const modelOutput = await prisma.modelOutput.create({ for (let i = 0; true; i++) {
const modelResponse = await prisma.modelResponse.create({
data: { data: {
scenarioVariantCellId: cellId,
inputHash, inputHash,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
output: response.value as Prisma.InputJsonObject, output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete, statusCode: response.statusCode,
receivedAt: new Date(),
promptTokens: response.promptTokens, promptTokens: response.promptTokens,
completionTokens: response.completionTokens, completionTokens: response.completionTokens,
cost: response.cost, cost: response.cost,
@@ -126,30 +123,35 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: cellId }, where: { id: cellId },
data: { data: {
statusCode: response.statusCode,
retrievalStatus: "COMPLETE", retrievalStatus: "COMPLETE",
}, },
}); });
await runEvalsForOutput(variant.experimentId, scenario, modelOutput); await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
break; break;
} else { } else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i); const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({ await prisma.modelResponse.update({
where: { id: cellId }, where: { id: modelResponse.id },
data: { data: {
errorMessage: response.message,
statusCode: response.statusCode, statusCode: response.statusCode,
errorMessage: response.message,
receivedAt: new Date(),
retryTime: shouldRetry ? new Date(Date.now() + delay) : null, retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: "ERROR",
}, },
}); });
if (shouldRetry) { if (shouldRetry) {
await sleep(delay); await sleep(delay);
} else { } else {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "ERROR",
},
});
break; break;
} }
} }
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
data: { data: {
retrievalStatus: "PENDING", retrievalStatus: "PENDING",
errorMessage: null, errorMessage: null,
jobQueuedAt: new Date(),
}, },
}), }),
queryModel.enqueue({ cellId, stream }), queryModel.enqueue({ cellId, stream }),

View File

@@ -1,19 +1,23 @@
import { type ModelOutput, type Evaluation } from "@prisma/client"; import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { runOneEval } from "./runOneEval"; import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types"; import { type Scenario } from "~/components/OutputsTable/types";
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => { const saveResult = async (
const result = await runOneEval(evaluation, scenario, modelOutput); evaluation: Evaluation,
scenario: Scenario,
modelResponse: ModelResponse,
) => {
const result = await runOneEval(evaluation, scenario, modelResponse);
return await prisma.outputEvaluation.upsert({ return await prisma.outputEvaluation.upsert({
where: { where: {
modelOutputId_evaluationId: { modelResponseId_evaluationId: {
modelOutputId: modelOutput.id, modelResponseId: modelResponse.id,
evaluationId: evaluation.id, evaluationId: evaluation.id,
}, },
}, },
create: { create: {
modelOutputId: modelOutput.id, modelResponseId: modelResponse.id,
evaluationId: evaluation.id, evaluationId: evaluation.id,
...result, ...result,
}, },
@@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
export const runEvalsForOutput = async ( export const runEvalsForOutput = async (
experimentId: string, experimentId: string,
scenario: Scenario, scenario: Scenario,
modelOutput: ModelOutput, modelResponse: ModelResponse,
) => { ) => {
const evaluations = await prisma.evaluation.findMany({ const evaluations = await prisma.evaluation.findMany({
where: { experimentId }, where: { experimentId },
}); });
await Promise.all( await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)), evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
); );
}; };
export const runAllEvals = async (experimentId: string) => { export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({ const outputs = await prisma.modelResponse.findMany({
where: { where: {
outdated: false,
output: {
not: Prisma.AnyNull,
},
scenarioVariantCell: { scenarioVariantCell: {
promptVariant: { promptVariant: {
experimentId, experimentId,

View File

@@ -1,4 +1,4 @@
import { type Prisma } from "@prisma/client"; import { Prisma } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import parseConstructFn from "./parseConstructFn"; import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest"; import { type JsonObject } from "type-fest";
@@ -35,7 +35,7 @@ export const generateNewCell = async (
}, },
}, },
include: { include: {
modelOutput: true, modelResponses: true,
}, },
}); });
@@ -51,8 +51,6 @@ export const generateNewCell = async (
data: { data: {
promptVariantId: variantId, promptVariantId: variantId,
testScenarioId: scenarioId, testScenarioId: scenarioId,
statusCode: 400,
errorMessage: parsedConstructFn.error,
retrievalStatus: "ERROR", retrievalStatus: "ERROR",
}, },
}); });
@@ -69,36 +67,55 @@ export const generateNewCell = async (
retrievalStatus: "PENDING", retrievalStatus: "PENDING",
}, },
include: { include: {
modelOutput: true, modelResponses: true,
}, },
}); });
const matchingModelOutput = await prisma.modelOutput.findFirst({ const matchingModelResponse = await prisma.modelResponse.findFirst({
where: { inputHash }, where: {
inputHash,
output: {
not: Prisma.AnyNull,
},
},
orderBy: {
receivedAt: "desc",
},
include: {
scenarioVariantCell: true,
},
take: 1,
}); });
if (matchingModelOutput) { if (matchingModelResponse) {
const newModelOutput = await prisma.modelOutput.create({ const newModelResponse = await prisma.modelResponse.create({
data: { data: {
...omit(matchingModelOutput, ["id"]), ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id, scenarioVariantCellId: cell.id,
output: matchingModelOutput.output as Prisma.InputJsonValue, output: matchingModelResponse.output as Prisma.InputJsonValue,
}, },
}); });
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: cell.id }, where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" }, data: {
retrievalStatus: "COMPLETE",
jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
},
}); });
// Copy over all eval results as well // Copy over all eval results as well
await Promise.all( await Promise.all(
( (
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } }) await prisma.outputEvaluation.findMany({
where: { modelResponseId: matchingModelResponse.id },
})
).map(async (evaluation) => { ).map(async (evaluation) => {
await prisma.outputEvaluation.create({ await prisma.outputEvaluation.create({
data: { data: {
...omit(evaluation, ["id"]), ...omit(evaluation, ["id"]),
modelOutputId: newModelOutput.id, modelResponseId: newModelResponse.id,
}, },
}); });
}), }),

View File

@@ -1,4 +1,4 @@
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client"; import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate"; import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai"; import { openai } from "./openai";
@@ -70,9 +70,9 @@ export const runGpt4Eval = async (
export const runOneEval = async ( export const runOneEval = async (
evaluation: Evaluation, evaluation: Evaluation,
scenario: TestScenario, scenario: TestScenario,
modelOutput: ModelOutput, modelResponse: ModelResponse,
): Promise<{ result: number; details?: string }> => { ): Promise<{ result: number; details?: string }> => {
const output = modelOutput.output as unknown as ChatCompletion; const output = modelResponse.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message; const message = output?.choices?.[0]?.message;