Store multiple ModelResponses (#95)
* Store multiple ModelResponses * Fix prettier * Add CellContent container
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "OutputEvaluation" DROP CONSTRAINT "OutputEvaluation_modelOutputId_fkey";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "OutputEvaluation_modelOutputId_evaluationId_key";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "OutputEvaluation" RENAME COLUMN "modelOutputId" TO "modelResponseId";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "retryTime",
|
||||
DROP COLUMN "statusCode",
|
||||
ADD COLUMN "jobQueuedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "jobStartedAt" TIMESTAMP(3);
|
||||
|
||||
ALTER TABLE "ModelOutput" RENAME TO "ModelResponse";
|
||||
|
||||
ALTER TABLE "ModelResponse"
|
||||
ADD COLUMN "requestedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "receivedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "statusCode" INTEGER,
|
||||
ADD COLUMN "errorMessage" TEXT,
|
||||
ADD COLUMN "retryTime" TIMESTAMP(3),
|
||||
ADD COLUMN "outdated" BOOLEAN NOT NULL DEFAULT false;
|
||||
|
||||
-- 3. Remove the unnecessary column
|
||||
ALTER TABLE "ModelResponse"
|
||||
DROP COLUMN "timeToComplete";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "ModelResponse" RENAME CONSTRAINT "ModelOutput_pkey" TO "ModelResponse_pkey";
|
||||
ALTER TABLE "ModelResponse" ALTER COLUMN "output" DROP NOT NULL;
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "ModelOutput_scenarioVariantCellId_key";
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ModelResponse" ADD CONSTRAINT "ModelResponse_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ModelResponse_inputHash_idx";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "OutputEvaluation_modelResponseId_evaluationId_key" ON "OutputEvaluation"("modelResponseId", "evaluationId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "OutputEvaluation" ADD CONSTRAINT "OutputEvaluation_modelResponseId_fkey" FOREIGN KEY ("modelResponseId") REFERENCES "ModelResponse"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
|
||||
@@ -90,12 +90,11 @@ enum CellRetrievalStatus {
|
||||
model ScenarioVariantCell {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
statusCode Int?
|
||||
errorMessage String?
|
||||
retryTime DateTime?
|
||||
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||
|
||||
modelOutput ModelOutput?
|
||||
jobQueuedAt DateTime?
|
||||
jobStartedAt DateTime?
|
||||
modelResponses ModelResponse[]
|
||||
errorMessage String? // Contains errors that occurred independently of model responses
|
||||
|
||||
promptVariantId String @db.Uuid
|
||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||
@@ -110,15 +109,20 @@ model ScenarioVariantCell {
|
||||
@@unique([promptVariantId, testScenarioId])
|
||||
}
|
||||
|
||||
model ModelOutput {
|
||||
model ModelResponse {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
inputHash String
|
||||
output Json
|
||||
timeToComplete Int @default(0)
|
||||
cost Float?
|
||||
promptTokens Int?
|
||||
completionTokens Int?
|
||||
inputHash String
|
||||
requestedAt DateTime?
|
||||
receivedAt DateTime?
|
||||
output Json?
|
||||
cost Float?
|
||||
promptTokens Int?
|
||||
completionTokens Int?
|
||||
statusCode Int?
|
||||
errorMessage String?
|
||||
retryTime DateTime?
|
||||
outdated Boolean @default(false)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
@@ -127,7 +131,6 @@ model ModelOutput {
|
||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||
outputEvaluations OutputEvaluation[]
|
||||
|
||||
@@unique([scenarioVariantCellId])
|
||||
@@index([inputHash])
|
||||
}
|
||||
|
||||
@@ -159,8 +162,8 @@ model OutputEvaluation {
|
||||
result Float
|
||||
details String?
|
||||
|
||||
modelOutputId String @db.Uuid
|
||||
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
|
||||
modelResponseId String @db.Uuid
|
||||
modelResponse ModelResponse @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
|
||||
|
||||
evaluationId String @db.Uuid
|
||||
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||
@@ -168,7 +171,7 @@ model OutputEvaluation {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([modelOutputId, evaluationId])
|
||||
@@unique([modelResponseId, evaluationId])
|
||||
}
|
||||
|
||||
model Organization {
|
||||
|
||||
17
src/components/OutputsTable/OutputCell/CellContent.tsx
Normal file
17
src/components/OutputsTable/OutputCell/CellContent.tsx
Normal file
@@ -0,0 +1,17 @@
|
||||
import { type StackProps, VStack } from "@chakra-ui/react";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
|
||||
export const CellContent = ({
|
||||
hardRefetch,
|
||||
hardRefetching,
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
hardRefetch: () => void;
|
||||
hardRefetching: boolean;
|
||||
} & StackProps) => (
|
||||
<VStack maxH={500} w="full" overflowY="auto" alignItems="flex-start" {...props}>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
{children}
|
||||
</VStack>
|
||||
);
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
|
||||
import { Button, HStack, Icon, Spinner, Tooltip } from "@chakra-ui/react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
import { useExperimentAccess } from "~/utils/hooks";
|
||||
|
||||
@@ -12,7 +12,7 @@ export const CellOptions = ({
|
||||
const { canModify } = useExperimentAccess();
|
||||
return (
|
||||
<HStack justifyContent="flex-end" w="full">
|
||||
{!refetchingOutput && canModify && (
|
||||
{canModify && (
|
||||
<Tooltip label="Refetch output" aria-label="refetch output">
|
||||
<Button
|
||||
size="xs"
|
||||
@@ -28,7 +28,7 @@ export const CellOptions = ({
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={4} />
|
||||
<Icon as={refetchingOutput ? Spinner : BsArrowClockwise} boxSize={4} />
|
||||
</Button>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement, useState, useEffect } from "react";
|
||||
import { type ReactElement, useState, useEffect, Fragment } from "react";
|
||||
import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
import { RetryCountdown } from "./RetryCountdown";
|
||||
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
|
||||
import { ResponseLog } from "./ResponseLog";
|
||||
import { CellContent } from "./CellContent";
|
||||
|
||||
const WAITING_MESSAGE_INTERVAL = 20000;
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -65,46 +68,91 @@ export default function OutputCell({
|
||||
hardRefetching;
|
||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||
|
||||
const modelOutput = cell?.modelOutput;
|
||||
|
||||
// TODO: disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket<OutputSchema>(cell?.id);
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
if (awaitingOutput && !streamedMessage)
|
||||
return (
|
||||
<Center h="100%" w="100%">
|
||||
<Spinner />
|
||||
</Center>
|
||||
);
|
||||
|
||||
if (!cell && !fetchingOutput)
|
||||
return (
|
||||
<VStack>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
|
||||
<Text color="gray.500">Error retrieving output</Text>
|
||||
</VStack>
|
||||
</CellContent>
|
||||
);
|
||||
|
||||
if (cell && cell.errorMessage) {
|
||||
return (
|
||||
<VStack>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<ErrorHandler cell={cell} refetchOutput={hardRefetch} />
|
||||
</VStack>
|
||||
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
|
||||
<Text color="red.500">{cell.errorMessage}</Text>
|
||||
</CellContent>
|
||||
);
|
||||
}
|
||||
|
||||
const normalizedOutput = modelOutput
|
||||
? provider.normalizeOutput(modelOutput.output)
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
const mostRecentResponse = cell?.modelResponses[cell.modelResponses.length - 1];
|
||||
const showLogs = !streamedMessage && !mostRecentResponse?.output;
|
||||
|
||||
if (showLogs)
|
||||
return (
|
||||
<CellContent
|
||||
hardRefetching={hardRefetching}
|
||||
hardRefetch={hardRefetch}
|
||||
alignItems="flex-start"
|
||||
fontFamily="inconsolata, monospace"
|
||||
spacing={0}
|
||||
>
|
||||
{cell?.jobQueuedAt && <ResponseLog time={cell.jobQueuedAt} title="Job queued" />}
|
||||
{cell?.jobStartedAt && <ResponseLog time={cell.jobStartedAt} title="Job started" />}
|
||||
{cell?.modelResponses?.map((response) => {
|
||||
let numWaitingMessages = 0;
|
||||
const relativeWaitingTime = response.receivedAt
|
||||
? response.receivedAt.getTime()
|
||||
: Date.now();
|
||||
if (response.requestedAt) {
|
||||
numWaitingMessages = Math.floor(
|
||||
(relativeWaitingTime - response.requestedAt.getTime()) / WAITING_MESSAGE_INTERVAL,
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Fragment key={response.id}>
|
||||
{response.requestedAt && (
|
||||
<ResponseLog time={response.requestedAt} title="Request sent to API" />
|
||||
)}
|
||||
{response.requestedAt &&
|
||||
Array.from({ length: numWaitingMessages }, (_, i) => (
|
||||
<ResponseLog
|
||||
key={`waiting-${i}`}
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
time={new Date(response.requestedAt!.getTime() + i * WAITING_MESSAGE_INTERVAL)}
|
||||
title="Waiting for response"
|
||||
/>
|
||||
))}
|
||||
{response.receivedAt && (
|
||||
<ResponseLog
|
||||
time={response.receivedAt}
|
||||
title="Response received from API"
|
||||
message={`statusCode: ${response.statusCode ?? ""}\n ${
|
||||
response.errorMessage ?? ""
|
||||
}`}
|
||||
/>
|
||||
)}
|
||||
</Fragment>
|
||||
);
|
||||
}) ?? null}
|
||||
{mostRecentResponse?.retryTime && (
|
||||
<RetryCountdown retryTime={mostRecentResponse.retryTime} />
|
||||
)}
|
||||
</CellContent>
|
||||
);
|
||||
|
||||
const normalizedOutput = mostRecentResponse?.output
|
||||
? provider.normalizeOutput(mostRecentResponse?.output)
|
||||
: streamedMessage
|
||||
? provider.normalizeOutput(streamedMessage)
|
||||
: null;
|
||||
|
||||
if (modelOutput && normalizedOutput?.type === "json") {
|
||||
if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
|
||||
return (
|
||||
<VStack
|
||||
w="100%"
|
||||
@@ -114,8 +162,13 @@ export default function OutputCell({
|
||||
overflowX="hidden"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<VStack w="full" flex={1} spacing={0}>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<CellContent
|
||||
hardRefetching={hardRefetching}
|
||||
hardRefetch={hardRefetch}
|
||||
w="full"
|
||||
flex={1}
|
||||
spacing={0}
|
||||
>
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset", width: "100%", flex: 1 }}
|
||||
language="json"
|
||||
@@ -127,8 +180,8 @@ export default function OutputCell({
|
||||
>
|
||||
{stringify(normalizedOutput.value, { maxLength: 40 })}
|
||||
</SyntaxHighlighter>
|
||||
</VStack>
|
||||
<OutputStats modelOutput={modelOutput} scenario={scenario} />
|
||||
</CellContent>
|
||||
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
@@ -138,10 +191,13 @@ export default function OutputCell({
|
||||
return (
|
||||
<VStack w="100%" h="100%" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||
<CellOptions refetchingOutput={hardRefetching} refetchOutput={hardRefetch} />
|
||||
<Text>{contentToDisplay}</Text>
|
||||
<CellContent hardRefetching={hardRefetching} hardRefetch={hardRefetch}>
|
||||
<Text>{contentToDisplay}</Text>
|
||||
</CellContent>
|
||||
</VStack>
|
||||
{modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
|
||||
{mostRecentResponse?.output && (
|
||||
<OutputStats modelResponse={mostRecentResponse} scenario={scenario} />
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,28 +7,32 @@ import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
const SHOW_TIME = true;
|
||||
|
||||
export const OutputStats = ({
|
||||
modelOutput,
|
||||
modelResponse,
|
||||
}: {
|
||||
modelOutput: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||
modelResponse: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelResponses"][0]
|
||||
>;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete = modelOutput.timeToComplete;
|
||||
const timeToComplete =
|
||||
modelResponse.receivedAt && modelResponse.requestedAt
|
||||
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
|
||||
: 0;
|
||||
|
||||
const promptTokens = modelOutput.promptTokens;
|
||||
const completionTokens = modelOutput.completionTokens;
|
||||
const promptTokens = modelResponse.promptTokens;
|
||||
const completionTokens = modelResponse.completionTokens;
|
||||
|
||||
return (
|
||||
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
{modelOutput.outputEvaluations.map((evaluation) => {
|
||||
{modelResponse.outputEvaluations.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<Tooltip
|
||||
isDisabled={!evaluation.details}
|
||||
label={evaluation.details}
|
||||
key={evaluation.id}
|
||||
shouldWrapChildren
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Text>{evaluation.evaluation.label}</Text>
|
||||
@@ -42,15 +46,15 @@ export const OutputStats = ({
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{modelOutput.cost && (
|
||||
{modelResponse.cost && (
|
||||
<CostTooltip
|
||||
promptTokens={promptTokens}
|
||||
completionTokens={completionTokens}
|
||||
cost={modelOutput.cost}
|
||||
cost={modelResponse.cost}
|
||||
>
|
||||
<HStack spacing={0}>
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
|
||||
<Text mr={1}>{modelResponse.cost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
</CostTooltip>
|
||||
)}
|
||||
|
||||
22
src/components/OutputsTable/OutputCell/ResponseLog.tsx
Normal file
22
src/components/OutputsTable/OutputCell/ResponseLog.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import { HStack, VStack, Text } from "@chakra-ui/react";
|
||||
import dayjs from "dayjs";
|
||||
|
||||
export const ResponseLog = ({
|
||||
time,
|
||||
title,
|
||||
message,
|
||||
}: {
|
||||
time: Date;
|
||||
title: string;
|
||||
message?: string;
|
||||
}) => {
|
||||
return (
|
||||
<VStack spacing={0} alignItems="flex-start">
|
||||
<HStack>
|
||||
<Text>{dayjs(time).format("HH:mm:ss")}</Text>
|
||||
<Text>{title}</Text>
|
||||
</HStack>
|
||||
{message && <Text pl={4}>{message}</Text>}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -1,21 +1,12 @@
|
||||
import { type ScenarioVariantCell } from "@prisma/client";
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { Text } from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
export const ErrorHandler = ({
|
||||
cell,
|
||||
refetchOutput,
|
||||
}: {
|
||||
cell: ScenarioVariantCell;
|
||||
refetchOutput: () => void;
|
||||
}) => {
|
||||
export const RetryCountdown = ({ retryTime }: { retryTime: Date }) => {
|
||||
const [msToWait, setMsToWait] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (!cell.retryTime) return;
|
||||
|
||||
const initialWaitTime = cell.retryTime.getTime() - Date.now();
|
||||
const initialWaitTime = retryTime.getTime() - Date.now();
|
||||
const msModuloOneSecond = initialWaitTime % 1000;
|
||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||
setMsToWait(remainingTime);
|
||||
@@ -36,18 +27,13 @@ export const ErrorHandler = ({
|
||||
clearInterval(interval);
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
|
||||
}, [retryTime]);
|
||||
|
||||
if (msToWait <= 0) return null;
|
||||
|
||||
return (
|
||||
<VStack w="full">
|
||||
<Text color="red.600" wordBreak="break-word">
|
||||
{cell.errorMessage}
|
||||
</Text>
|
||||
{msToWait > 0 && (
|
||||
<Text color="red.600" fontSize="sm">
|
||||
Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
|
||||
</Text>
|
||||
)}
|
||||
</VStack>
|
||||
<Text color="red.600" fontSize="sm">
|
||||
Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
@@ -118,7 +118,7 @@ export const experimentsRouter = createTRPCRouter({
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: {
|
||||
modelResponses: {
|
||||
include: {
|
||||
outputEvaluations: true,
|
||||
},
|
||||
@@ -177,11 +177,11 @@ export const experimentsRouter = createTRPCRouter({
|
||||
}
|
||||
|
||||
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||
const modelOutputsToCreate: Prisma.ModelOutputCreateManyInput[] = [];
|
||||
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
|
||||
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||
for (const cell of existingCells) {
|
||||
const newCellId = uuidv4();
|
||||
const { modelOutput, ...cellData } = cell;
|
||||
const { modelResponses, ...cellData } = cell;
|
||||
cellsToCreate.push({
|
||||
...cellData,
|
||||
id: newCellId,
|
||||
@@ -189,20 +189,20 @@ export const experimentsRouter = createTRPCRouter({
|
||||
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
if (modelOutput) {
|
||||
const newModelOutputId = uuidv4();
|
||||
const { outputEvaluations, ...modelOutputData } = modelOutput;
|
||||
modelOutputsToCreate.push({
|
||||
...modelOutputData,
|
||||
id: newModelOutputId,
|
||||
for (const modelResponse of modelResponses) {
|
||||
const newModelResponseId = uuidv4();
|
||||
const { outputEvaluations, ...modelResponseData } = modelResponse;
|
||||
modelResponsesToCreate.push({
|
||||
...modelResponseData,
|
||||
id: newModelResponseId,
|
||||
scenarioVariantCellId: newCellId,
|
||||
output: (modelOutput.output as Prisma.InputJsonValue) ?? undefined,
|
||||
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const evaluation of outputEvaluations) {
|
||||
outputEvaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: uuidv4(),
|
||||
modelOutputId: newModelOutputId,
|
||||
modelResponseId: newModelResponseId,
|
||||
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||
});
|
||||
}
|
||||
@@ -245,8 +245,8 @@ export const experimentsRouter = createTRPCRouter({
|
||||
prisma.scenarioVariantCell.createMany({
|
||||
data: cellsToCreate,
|
||||
}),
|
||||
prisma.modelOutput.createMany({
|
||||
data: modelOutputsToCreate,
|
||||
prisma.modelResponse.createMany({
|
||||
data: modelResponsesToCreate,
|
||||
}),
|
||||
prisma.evaluation.createMany({
|
||||
data: evaluationsToCreate,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
@@ -51,7 +52,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelOutput: {
|
||||
modelResponse: {
|
||||
outdated: false,
|
||||
output: { not: Prisma.AnyNull },
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
@@ -93,14 +96,23 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelOutput: {
|
||||
is: {},
|
||||
modelResponses: {
|
||||
some: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const overallTokens = await prisma.modelOutput.aggregate({
|
||||
const overallTokens = await prisma.modelResponse.aggregate({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: {
|
||||
|
||||
@@ -27,7 +27,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
@@ -62,7 +65,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: { modelOutput: true },
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
@@ -70,12 +72,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
return;
|
||||
}
|
||||
|
||||
if (cell.modelOutput) {
|
||||
// TODO: Maybe keep these around to show previous generations?
|
||||
await prisma.modelOutput.delete({
|
||||
where: { id: cell.modelOutput.id },
|
||||
});
|
||||
}
|
||||
await prisma.modelResponse.updateMany({
|
||||
where: { scenarioVariantCellId: cell.id },
|
||||
data: {
|
||||
outdated: true,
|
||||
},
|
||||
});
|
||||
|
||||
await queueQueryModel(cell.id, true);
|
||||
}),
|
||||
|
||||
@@ -29,17 +29,9 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
const { cellId, stream } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: cellId },
|
||||
include: { modelOutput: true },
|
||||
include: { modelResponses: true },
|
||||
});
|
||||
if (!cell) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Cell not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -51,6 +43,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "IN_PROGRESS",
|
||||
jobStartedAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
@@ -61,7 +54,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Prompt Variant not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -76,7 +68,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Scenario not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -90,7 +81,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 400,
|
||||
errorMessage: prompt.error,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -106,17 +96,24 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
}
|
||||
: null;
|
||||
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
const modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
requestedAt: new Date(),
|
||||
},
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
const modelOutput = await prisma.modelOutput.create({
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
scenarioVariantCellId: cellId,
|
||||
inputHash,
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
timeToComplete: response.timeToComplete,
|
||||
statusCode: response.statusCode,
|
||||
receivedAt: new Date(),
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
cost: response.cost,
|
||||
@@ -126,30 +123,35 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
retrievalStatus: "COMPLETE",
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
|
||||
break;
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||
const delay = calculateDelay(i);
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
errorMessage: response.message,
|
||||
statusCode: response.statusCode,
|
||||
errorMessage: response.message,
|
||||
receivedAt: new Date(),
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRetry) {
|
||||
await sleep(delay);
|
||||
} else {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
jobQueuedAt: new Date(),
|
||||
},
|
||||
}),
|
||||
queryModel.enqueue({ cellId, stream }),
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
|
||||
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelOutput);
|
||||
const saveResult = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelResponse);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
modelOutputId_evaluationId: {
|
||||
modelOutputId: modelOutput.id,
|
||||
modelResponseId_evaluationId: {
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelOutputId: modelOutput.id,
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
...result,
|
||||
},
|
||||
@@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelOutput: ModelOutput,
|
||||
modelResponse: ModelResponse,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
|
||||
);
|
||||
};
|
||||
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelOutput.findMany({
|
||||
const outputs = await prisma.modelResponse.findMany({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
@@ -35,7 +35,7 @@ export const generateNewCell = async (
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -51,8 +51,6 @@ export const generateNewCell = async (
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
statusCode: 400,
|
||||
errorMessage: parsedConstructFn.error,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
@@ -69,36 +67,55 @@ export const generateNewCell = async (
|
||||
retrievalStatus: "PENDING",
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: { inputHash },
|
||||
const matchingModelResponse = await prisma.modelResponse.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
receivedAt: "desc",
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: true,
|
||||
},
|
||||
take: 1,
|
||||
});
|
||||
|
||||
if (matchingModelOutput) {
|
||||
const newModelOutput = await prisma.modelOutput.create({
|
||||
if (matchingModelResponse) {
|
||||
const newModelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
...omit(matchingModelOutput, ["id"]),
|
||||
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
|
||||
scenarioVariantCellId: cell.id,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
output: matchingModelResponse.output as Prisma.InputJsonValue,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "COMPLETE" },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
|
||||
jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
|
||||
},
|
||||
});
|
||||
|
||||
// Copy over all eval results as well
|
||||
await Promise.all(
|
||||
(
|
||||
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
|
||||
await prisma.outputEvaluation.findMany({
|
||||
where: { modelResponseId: matchingModelResponse.id },
|
||||
})
|
||||
).map(async (evaluation) => {
|
||||
await prisma.outputEvaluation.create({
|
||||
data: {
|
||||
...omit(evaluation, ["id"]),
|
||||
modelOutputId: newModelOutput.id,
|
||||
modelResponseId: newModelResponse.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
@@ -70,9 +70,9 @@ export const runGpt4Eval = async (
|
||||
export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelOutput: ModelOutput,
|
||||
modelResponse: ModelResponse,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
const output = modelResponse.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user