Trigger llm output retrieval on server (#39)

* Rename tables, add graphile workers, update types

* Add dev:worker command

* Update pnpm-lock.yaml

* Remove sentry config import from worker.ts

* Stop generating new cells in cell router get query

* Generate new cells for new scenarios, variants, and experiments

* Remove most error throwing from queryLLM.task.ts

* Remove promptVariantId and testScenarioId from ModelOutput

* Remove duplicate index from ModelOutput

* Move inputHash from cell to output

* Add TODO

* Add todo

* Show cost and time for each cell

* Always show output stats if there is output

* Trigger LLM outputs when scenario variables are updated

* Add newlines to ends of files

* Add another newline

* Cascade ModelOutput deletion

* Fix linting and prettier

* Return instead of throwing for non-pending cell

* Remove pnpm dev:worker from pnpm:dev

* Update pnpm-lock.yaml
This commit is contained in:
arcticfly
2023-07-14 16:38:46 -06:00
committed by GitHub
parent 032c07ec65
commit b98eb9b729
29 changed files with 1089 additions and 407 deletions

View File

@@ -11,6 +11,7 @@
"build": "next build", "build": "next build",
"dev:next": "next dev", "dev:next": "next dev",
"dev:wss": "pnpm tsx --watch src/wss-server.ts", "dev:wss": "pnpm tsx --watch src/wss-server.ts",
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'", "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
"postinstall": "prisma generate", "postinstall": "prisma generate",
"lint": "next lint", "lint": "next lint",
@@ -44,6 +45,7 @@
"express": "^4.18.2", "express": "^4.18.2",
"framer-motion": "^10.12.17", "framer-motion": "^10.12.17",
"gpt-tokens": "^1.0.10", "gpt-tokens": "^1.0.10",
"graphile-worker": "^0.13.0",
"immer": "^10.0.2", "immer": "^10.0.2",
"isolated-vm": "^4.5.0", "isolated-vm": "^4.5.0",
"json-stringify-pretty-compact": "^4.0.0", "json-stringify-pretty-compact": "^4.0.0",

514
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
ALTER COLUMN "statusCode" DROP NOT NULL,
ALTER COLUMN "timeToComplete" DROP NOT NULL;
-- Create the new table
CREATE TABLE "ModelOutput" (
"id" UUID NOT NULL,
"inputHash" TEXT NOT NULL,
"output" JSONB NOT NULL,
"timeToComplete" INTEGER NOT NULL DEFAULT 0,
"promptTokens" INTEGER,
"completionTokens" INTEGER,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"scenarioVariantCellId" UUID
);
-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

View File

@@ -41,7 +41,7 @@ model PromptVariant {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
ModelOutput ModelOutput[] scenarioVariantCells ScenarioVariantCell[]
EvaluationResult EvaluationResult[] EvaluationResult EvaluationResult[]
@@index([uiId]) @@index([uiId])
@@ -61,7 +61,7 @@ model TestScenario {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
ModelOutput ModelOutput[] scenarioVariantCells ScenarioVariantCell[]
} }
model TemplateVariable { model TemplateVariable {
@@ -76,17 +76,28 @@ model TemplateVariable {
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
} }
model ModelOutput { enum CellRetrievalStatus {
PENDING
IN_PROGRESS
COMPLETE
ERROR
}
model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
inputHash String inputHash String? // TODO: Remove once migration is complete
output Json output Json? // TODO: Remove once migration is complete
statusCode Int statusCode Int?
errorMessage String? errorMessage String?
timeToComplete Int @default(0) timeToComplete Int? @default(0) // TODO: Remove once migration is complete
retryTime DateTime?
streamingChannel String?
retrievalStatus CellRetrievalStatus @default(COMPLETE)
promptTokens Int? // Added promptTokens field promptTokens Int? // TODO: Remove once migration is complete
completionTokens Int? // Added completionTokens field completionTokens Int? // TODO: Remove once migration is complete
modelOutput ModelOutput?
promptVariantId String @db.Uuid promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -98,6 +109,24 @@ model ModelOutput {
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@unique([promptVariantId, testScenarioId]) @@unique([promptVariantId, testScenarioId])
}
model ModelOutput {
id String @id @default(uuid()) @db.Uuid
inputHash String
output Json
timeToComplete Int @default(0)
promptTokens Int?
completionTokens Int?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
@@unique([scenarioVariantCellId])
@@index([inputHash]) @@index([inputHash])
} }

View File

@@ -16,7 +16,7 @@ const experiment = await prisma.experiment.create({
}, },
}); });
await prisma.modelOutput.deleteMany({ await prisma.scenarioVariantCell.deleteMany({
where: { where: {
promptVariant: { promptVariant: {
experimentId, experimentId,

View File

@@ -0,0 +1,33 @@
import { Button, HStack, Icon } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
/**
 * Right-aligned options row shown above a cell's output.
 *
 * Renders a single ghost button that invokes `refetchOutput` when clicked.
 * The button is omitted while `refetchingOutput` is true, so a refetch cannot
 * be requested again from this row until the flag clears.
 */
export const CellOptions = (props: {
  refetchingOutput: boolean;
  refetchOutput: () => void;
}) => {
  const { refetchingOutput, refetchOutput } = props;

  // Nothing is rendered in place of the button while a refetch is in flight.
  const refetchButton = refetchingOutput ? null : (
    <Button
      aria-label="refetch output"
      onClick={refetchOutput}
      variant="ghost"
      cursor="pointer"
      color="gray.500"
      size="xs"
      w={4}
      h={4}
      px={4}
      py={4}
      minW={0}
      borderRadius={8}
    >
      <Icon as={BsArrowClockwise} boxSize={4} />
    </Button>
  );

  return (
    <HStack justifyContent="flex-end" w="full">
      {refetchButton}
    </HStack>
  );
};

View File

@@ -1,29 +1,21 @@
import { type ModelOutput } from "@prisma/client"; import { type ScenarioVariantCell } from "@prisma/client";
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react"; import { VStack, Text } from "@chakra-ui/react";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { BsArrowClockwise } from "react-icons/bs";
import { rateLimitErrorMessage } from "~/sharedStrings";
import pluralize from "pluralize"; import pluralize from "pluralize";
const MAX_AUTO_RETRIES = 3;
export const ErrorHandler = ({ export const ErrorHandler = ({
output, cell,
refetchOutput, refetchOutput,
numPreviousTries,
}: { }: {
output: ModelOutput; cell: ScenarioVariantCell;
refetchOutput: () => void; refetchOutput: () => void;
numPreviousTries: number;
}) => { }) => {
const [msToWait, setMsToWait] = useState(0); const [msToWait, setMsToWait] = useState(0);
const shouldAutoRetry =
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
useEffect(() => { useEffect(() => {
if (!shouldAutoRetry) return; if (!cell.retryTime) return;
const initialWaitTime = calculateDelay(numPreviousTries); const initialWaitTime = cell.retryTime.getTime() - Date.now();
const msModuloOneSecond = initialWaitTime % 1000; const msModuloOneSecond = initialWaitTime % 1000;
let remainingTime = initialWaitTime - msModuloOneSecond; let remainingTime = initialWaitTime - msModuloOneSecond;
setMsToWait(remainingTime); setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
setMsToWait(remainingTime); setMsToWait(remainingTime);
if (remainingTime <= 0) { if (remainingTime <= 0) {
refetchOutput();
clearInterval(interval); clearInterval(interval);
} }
}, 1000); }, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
clearInterval(interval); clearInterval(interval);
clearTimeout(timeout); clearTimeout(timeout);
}; };
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]); }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
return ( return (
<VStack w="full"> <VStack w="full">
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
<Text color="red.600" fontWeight="bold">
Error
</Text>
<Button
size="xs"
w={4}
h={4}
px={4}
py={4}
minW={0}
borderRadius={8}
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={6} />
</Button>
</HStack>
<Text color="red.600" wordBreak="break-word"> <Text color="red.600" wordBreak="break-word">
{output.errorMessage} {cell.errorMessage}
</Text> </Text>
{msToWait > 0 && ( {msToWait > 0 && (
<Text color="red.600" fontSize="sm"> <Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
</VStack> </VStack>
); );
}; };
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds
function calculateDelay(numPreviousTries: number): number {
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
const jitter = Math.random() * baseDelay;
return baseDelay + jitter;
}

View File

@@ -1,17 +1,16 @@
import { type RouterOutputs, api } from "~/utils/api"; import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types"; import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react"; import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react"; import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket"; import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats"; import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler"; import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -37,88 +36,68 @@ export default function OutputCell({
// if (variant.config === null || Object.keys(variant.config).length === 0) // if (variant.config === null || Object.keys(variant.config).length === 0)
// disabledReason = "Save your prompt variant to see output"; // disabledReason = "Save your prompt variant to see output";
const outputMutation = api.outputs.get.useMutation(); const [refetchInterval, setRefetchInterval] = useState(0);
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
{ scenarioId: scenario.id, variantId: variant.id },
{ refetchInterval },
);
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null); const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
const [channel, setChannel] = useState<string | undefined>(undefined); api.scenarioVariantCells.forceRefetch.useMutation();
const [numPreviousTries, setNumPreviousTries] = useState(0); const [hardRefetch] = useHandledAsyncCallback(async () => {
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
const fetchMutex = useRef(false); await utils.scenarioVariantCells.get.invalidate({
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
async (forceRefetch?: boolean) => {
if (fetchMutex.current) return;
setNumPreviousTries((prev) => prev + 1);
fetchMutex.current = true;
setOutput(null);
const shouldStream =
isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channel = shouldStream ? generateChannel() : undefined;
setChannel(channel);
const output = await outputMutation.mutateAsync({
scenarioId: scenario.id, scenarioId: scenario.id,
variantId: variant.id, variantId: variant.id,
channel,
forceRefetch,
}); });
setOutput(output); }, [hardRefetchMutate, scenario.id, variant.id]);
await utils.promptVariants.stats.invalidate();
fetchMutex.current = false;
},
[outputMutation, scenario.id, variant.id],
);
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
useEffect(fetchOutput, [scenario.id, variant.id]); const fetchingOutput = queryLoading || refetchingOutput;
const awaitingOutput =
!cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore // Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(fetchingOutput ? channel : undefined); const streamedMessage = useSocket(cell?.streamingChannel);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content; const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null; if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>; if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (fetchingOutput && !streamedMessage) if (awaitingOutput && !streamedMessage)
return ( return (
<Center h="100%" w="100%"> <Center h="100%" w="100%">
<Spinner /> <Spinner />
</Center> </Center>
); );
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>; if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (output && output.errorMessage) { if (cell && cell.errorMessage) {
return ( return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
<ErrorHandler
output={output}
refetchOutput={hardRefetch}
numPreviousTries={numPreviousTries}
/>
);
} }
const response = output?.output as unknown as ChatCompletion; const response = modelOutput?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message; const message = response?.choices?.[0]?.message;
if (output && message?.function_call) { if (modelOutput && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null"; const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string; let parsedArgs: string;
try { try {
parsedArgs = JSON.parse(rawArgs); parsedArgs = JSON.parse(rawArgs);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) { } catch (e: any) {
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`; parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
} }
return ( return (
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto"> <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
<VStack w="full" spacing={0}>
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<SyntaxHighlighter <SyntaxHighlighter
customStyle={{ overflowX: "unset" }} customStyle={{ overflowX: "unset" }}
language="json" language="json"
@@ -136,17 +115,24 @@ export default function OutputCell({
{ maxLength: 40 }, { maxLength: 40 },
)} )}
</SyntaxHighlighter> </SyntaxHighlighter>
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} /> </VStack>
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
</Box> </Box>
); );
} }
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output); const contentToDisplay =
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return ( return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap"> <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{contentToDisplay} <VStack w="full" alignItems="flex-start" spacing={0}>
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />} <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<Text>{contentToDisplay}</Text>
</VStack>
{modelOutput && (
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
)}
</Flex> </Flex>
); );
} }

View File

@@ -9,8 +9,8 @@ import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip"; import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_COST = false; const SHOW_COST = true;
const SHOW_TIME = false; const SHOW_TIME = true;
export const OutputStats = ({ export const OutputStats = ({
model, model,
@@ -35,8 +35,6 @@ export const OutputStats = ({
const cost = promptCost + completionCost; const cost = promptCost + completionCost;
if (!evals.length) return null;
return ( return (
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>

View File

@@ -1,7 +1,7 @@
import { type GetServerSideProps } from "next"; import { type GetServerSideProps } from "next";
// eslint-disable-next-line @typescript-eslint/require-await // eslint-disable-next-line @typescript-eslint/require-await
export const getServerSideProps: GetServerSideProps = async (context) => { export const getServerSideProps: GetServerSideProps = async () => {
return { return {
redirect: { redirect: {
destination: "/experiments", destination: "/experiments",

View File

@@ -3,10 +3,6 @@ import { prisma } from "../db";
import { openai } from "../utils/openai"; import { openai } from "../utils/openai";
import { pick } from "lodash"; import { pick } from "lodash";
function promptHasVariable(prompt: string, variableName: string) {
return prompt.includes(`{{${variableName}}}`);
}
type AxiosError = { type AxiosError = {
response?: { response?: {
data?: { data?: {

View File

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
import { createTRPCRouter } from "~/server/api/trpc"; import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router"; import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router"; import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router"; import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
import { templateVarsRouter } from "./routers/templateVariables.router"; import { templateVarsRouter } from "./routers/templateVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router"; import { evaluationsRouter } from "./routers/evaluations.router";
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
promptVariants: promptVariantsRouter, promptVariants: promptVariantsRouter,
experiments: experimentsRouter, experiments: experimentsRouter,
scenarios: scenariosRouter, scenarios: scenariosRouter,
outputs: modelOutputsRouter, scenarioVariantCells: scenarioVariantCellsRouter,
templateVars: templateVarsRouter, templateVars: templateVarsRouter,
evaluations: evaluationsRouter, evaluations: evaluationsRouter,
}); });

View File

@@ -2,6 +2,7 @@ import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import dedent from "dedent"; import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
export const experimentsRouter = createTRPCRouter({ export const experimentsRouter = createTRPCRouter({
list: publicProcedure.query(async () => { list: publicProcedure.query(async () => {
@@ -64,7 +65,7 @@ export const experimentsRouter = createTRPCRouter({
}, },
}); });
await prisma.$transaction([ const [variant, scenario] = await prisma.$transaction([
prisma.promptVariant.create({ prisma.promptVariant.create({
data: { data: {
experimentId: exp.id, experimentId: exp.id,
@@ -86,6 +87,8 @@ export const experimentsRouter = createTRPCRouter({
}), }),
]); ]);
await generateNewCell(variant.id, scenario.id);
return exp; return exp;
}), }),

View File

@@ -1,101 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
// tRPC router exposing a single `get` mutation that returns the stored ModelOutput
// for a (variant, scenario) pair, generating and persisting one when none exists
// or when the caller sets `forceRefetch`.
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
// Forwarded unchanged to getCompletion; presumably the streaming channel id — confirm.
channel: z.string().optional(),
// When true, the stored row for this pair is ignored and overwritten.
forceRefetch: z.boolean().optional(),
}),
)
.mutation(async ({ input }) => {
// Fast path: return the row already stored for this exact variant/scenario pair.
const existing = await prisma.modelOutput.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
});
if (existing && !input.forceRefetch) return existing;
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: input.scenarioId,
},
});
// Unknown variant or scenario: nothing to generate, so the mutation resolves to null.
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
// SHA-256 of the serialized prompt, used below to find a reusable response.
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
// TODO: we should probably only use this if temperature=0
// Reuse any stored error-free output with the same hash. The lookup is not scoped
// to this variant/scenario, and it is not skipped when forceRefetch is set.
const existingResponse = await prisma.modelOutput.findFirst({
where: { inputHash, errorMessage: null },
});
let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
if (existingResponse) {
// Copy the reusable row's fields; null token counts are mapped to undefined.
modelResponse = {
output: existingResponse.output as Prisma.InputJsonValue,
statusCode: existingResponse.statusCode,
errorMessage: existingResponse.errorMessage,
timeToComplete: existingResponse.timeToComplete,
promptTokens: existingResponse.promptTokens ?? undefined,
completionTokens: existingResponse.completionTokens ?? undefined,
};
} else {
try {
modelResponse = await getCompletion(
prompt as unknown as CompletionCreateParams,
input.channel,
);
} catch (e) {
// Log the failure, then rethrow so the mutation itself rejects.
console.error(e);
throw e;
}
}
// Create or overwrite the row for this pair.
// NOTE(review): the update branch does not write inputHash, so an existing row keeps
// its previous hash even when the prompt has changed — confirm this is intended.
const modelOutput = await prisma.modelOutput.upsert({
where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
},
},
create: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
inputHash,
...modelResponse,
},
update: {
...modelResponse,
},
});
// Re-run evaluations for the variant after its output was written.
await reevaluateVariant(input.variantId);
return modelOutput;
}),
});

View File

@@ -2,6 +2,7 @@ import { isObject } from "lodash";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel } from "~/server/types"; import { OpenAIChatModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt"; import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error"; import userError from "~/server/utils/error";
@@ -43,17 +44,24 @@ export const promptVariantsRouter = createTRPCRouter({
visible: true, visible: true,
}, },
}); });
const outputCount = await prisma.modelOutput.count({ const outputCount = await prisma.scenarioVariantCell.count({
where: { where: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { visible: true }, testScenario: { visible: true },
modelOutput: {
isNot: null,
},
}, },
}); });
const overallTokens = await prisma.modelOutput.aggregate({ const overallTokens = await prisma.modelOutput.aggregate({
where: { where: {
scenarioVariantCell: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { visible: true }, testScenario: {
visible: true,
},
},
}, },
_sum: { _sum: {
promptTokens: true, promptTokens: true,
@@ -115,6 +123,17 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId), recordExperimentUpdated(input.experimentId),
]); ]);
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return newVariant; return newVariant;
}), }),
@@ -234,6 +253,17 @@ export const promptVariantsRouter = createTRPCRouter({
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]); await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: newVariant.experimentId,
visible: true,
},
});
for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id);
}
return { status: "ok" } as const; return { status: "ok" } as const;
}), }),

View File

@@ -0,0 +1,68 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
/** Input shared by every procedure in this router: identifies one cell by scenario and variant. */
const cellKeyInput = z.object({
  scenarioId: z.string(),
  variantId: z.string(),
});

/**
 * Loads the cell for the given (variant, scenario) pair together with its
 * ModelOutput relation. Resolves to null when no such cell exists.
 */
const findCellWithOutput = (key: z.infer<typeof cellKeyInput>) =>
  prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: key.variantId,
        testScenarioId: key.scenarioId,
      },
    },
    include: {
      modelOutput: true,
    },
  });

/**
 * tRPC router for scenario/variant cells.
 *
 * - `get`: read-only lookup of a cell and its model output (null if the cell does not exist).
 * - `forceRefetch`: discards the cell's current model output, marks the cell PENDING and
 *   queues a new LLM retrieval task. If the cell does not exist yet it is generated instead.
 *   Always resolves to `true`.
 */
export const scenarioVariantCellsRouter = createTRPCRouter({
  get: publicProcedure.input(cellKeyInput).query(async ({ input }) => {
    return await findCellWithOutput(input);
  }),

  forceRefetch: publicProcedure.input(cellKeyInput).mutation(async ({ input }) => {
    const cell = await findCellWithOutput(input);

    if (!cell) {
      await generateNewCell(input.variantId, input.scenarioId);
      return true;
    }

    // Built on demand so the same update can run either standalone or inside a transaction.
    const markPending = () =>
      prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: { retrievalStatus: "PENDING" },
      });

    if (cell.modelOutput) {
      // TODO: Maybe keep these around to show previous generations?
      // Delete the old output and reset the status atomically: if only the delete
      // succeeded, the cell would stay COMPLETE with no output attached.
      await prisma.$transaction([
        prisma.modelOutput.delete({
          where: { id: cell.modelOutput.id },
        }),
        markPending(),
      ]);
    } else {
      await markPending();
    }

    // Queue the job only after the cell is durably marked PENDING.
    await queueLLMRetrievalTask(cell.id);
    return true;
  }),
});

View File

@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen"; import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations"; import { reevaluateAll } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
export const scenariosRouter = createTRPCRouter({ export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -48,10 +49,21 @@ export const scenariosRouter = createTRPCRouter({
}, },
}); });
await prisma.$transaction([ const [scenario] = await prisma.$transaction([
createNewScenarioAction, createNewScenarioAction,
recordExperimentUpdated(input.experimentId), recordExperimentUpdated(input.experimentId),
]); ]);
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, scenario.id);
}
}), }),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
@@ -175,6 +187,17 @@ export const scenariosRouter = createTRPCRouter({
}, },
}); });
const promptVariants = await prisma.promptVariant.findMany({
where: {
experimentId: newScenario.experimentId,
visible: true,
},
});
for (const variant of promptVariants) {
await generateNewCell(variant.id, newScenario.id);
}
return newScenario; return newScenario;
}), }),
}); });

View File

@@ -0,0 +1,47 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
/**
 * One-off data backfill over every ScenarioVariantCell.
 *
 * For each cell:
 * - if it still carries a legacy `output` and has no ModelOutput row, a ModelOutput
 *   is created from the cell's legacy columns;
 * - otherwise, if it has an `errorMessage` while its status is COMPLETE, the status
 *   is changed to ERROR.
 *
 * Rows are processed sequentially. Progress is reported through `console.log`.
 */
async function migrateScenarioVariantOutputData() {
  const allCells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
  console.log(`Found ${allCells.length} records`);

  let updatedCount = 0;

  for (const current of allCells) {
    // Legacy output present but not yet copied into its own ModelOutput row.
    if (current.output && !current.modelOutput) {
      updatedCount += 1;
      await prisma.modelOutput.create({
        data: {
          scenarioVariantCellId: current.id,
          inputHash: current.inputHash ?? "",
          output: current.output as Prisma.InputJsonValue,
          timeToComplete: current.timeToComplete ?? undefined,
          promptTokens: current.promptTokens,
          completionTokens: current.completionTokens,
          // Carry over the original timestamps.
          createdAt: current.createdAt,
          updatedAt: current.updatedAt,
        },
      });
      continue;
    }

    // Errored cells still carry the COMPLETE default status; mark them as ERROR.
    if (current.errorMessage && current.retrievalStatus === "COMPLETE") {
      updatedCount += 1;
      await prisma.scenarioVariantCell.update({
        where: { id: current.id },
        data: {
          retrievalStatus: "ERROR",
        },
      });
    }
  }

  console.log("Data migration completed");
  console.log(`Updated ${updatedCount} records`);
}
// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
console.error("An error occurred while migrating data: ", error);
process.exit(1);
});

View File

@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";
// Define the defineTask function
/**
 * Bundles a graphile-worker task with a typed helper for enqueuing it.
 *
 * @param taskIdentifier - Name under which the job is queued and registered.
 * @param taskHandler - Async function that processes a job payload.
 * @returns `enqueue`, which adds a job for this task via `quickAddJob` using
 *   `env.DATABASE_URL`, and `task`, the `{ identifier, handler }` pair where
 *   the handler logs the payload through `helpers.logger` before delegating
 *   to `taskHandler`.
 */
function defineTask<TPayload>(
  taskIdentifier: string,
  taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
  async function enqueue(payload: TPayload) {
    console.log("Enqueuing task", taskIdentifier, payload);
    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
  }

  function loggingHandler(payload: TPayload, helpers: Helpers) {
    helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
    return taskHandler(payload, helpers);
  }

  return {
    enqueue,
    task: {
      identifier: taskIdentifier,
      // Cast to graphile-worker's untyped Task signature so it can be registered.
      handler: loggingHandler as Task,
    },
  };
}

export default defineTask;

View File

@@ -0,0 +1,144 @@
import crypto from "crypto";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";
// Upper bound on the number of completion attempts for a single cell.
const MAX_AUTO_RETRIES = 10;
// Bounds (in milliseconds) applied to the exponential base delay, before jitter.
const MIN_DELAY = 500;
const MAX_DELAY = 15000;

/**
 * Computes how long to wait before the next retry.
 *
 * The base delay doubles with every previous try, starting at MIN_DELAY and
 * capped at MAX_DELAY. A random jitter in [0, base) is then added, so the
 * returned value lies in [base, 2 * base).
 *
 * @param numPreviousTries - Number of attempts already made.
 * @returns Delay in milliseconds.
 */
function calculateDelay(numPreviousTries: number): number {
  const exponentialDelay = MIN_DELAY * 2 ** numPreviousTries;
  const cappedDelay = Math.min(MAX_DELAY, exponentialDelay);
  return cappedDelay + Math.random() * cappedDelay;
}
/**
 * Requests a completion, retrying with backoff while the response status is 429.
 *
 * Any non-429 response is returned immediately. After a 429 (unless it was the
 * final attempt, in which case that response is returned as-is) the cell is
 * updated with the rate-limit error and the time of the next retry, and the
 * function sleeps for the computed delay before trying again.
 *
 * @param cellId - Id of the ScenarioVariantCell to annotate while waiting.
 * @param payload - Request payload forwarded to `getCompletion`.
 * @param channel - Optional channel forwarded to `getCompletion`.
 */
const getCompletionWithRetries = async (
  cellId: string,
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> => {
  let attempt = 0;

  while (attempt < MAX_AUTO_RETRIES) {
    const response = await getCompletion(payload as unknown as CompletionCreateParams, channel);

    const rateLimited = response.statusCode === 429;
    const isFinalAttempt = attempt === MAX_AUTO_RETRIES - 1;
    if (!rateLimited || isFinalAttempt) {
      return response;
    }

    const backoffMs = calculateDelay(attempt);
    const retryTime = new Date(Date.now() + backoffMs);

    await prisma.scenarioVariantCell.update({
      where: { id: cellId },
      data: {
        errorMessage: "Rate limit exceeded",
        statusCode: 429,
        retryTime,
      },
    });

    // TODO: Maybe requeue the job so other jobs can run in the future?
    await sleep(backoffMs);

    attempt += 1;
  }

  // Only reachable if the loop body never runs; the final attempt always returns.
  throw new Error("Max retries limit reached");
};
/** Payload of a `queryLLM` job: the cell whose output should be retrieved. */
export type queryLLMJob = {
  scenarioVariantCellId: string;
};

/**
 * Task that retrieves the LLM output for a single ScenarioVariantCell.
 *
 * Steps:
 * 1. Load the cell; do nothing if it is missing or not in the PENDING state.
 * 2. Mark the cell IN_PROGRESS.
 * 3. Load the prompt variant and test scenario (returns silently if either is
 *    missing) and construct the prompt.
 * 4. If the prompt requests streaming, store a fresh streaming channel on the
 *    cell so the UI can connect to it.
 * 5. Request the completion (with rate-limit retries).
 * 6. On a 200 response, store a ModelOutput row (including the sha256 hash of
 *    the prompt); then record status code, error message and the final
 *    retrieval status (COMPLETE or ERROR) on the cell.
 * 7. Re-run evaluations for the prompt variant.
 */
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
  const { scenarioVariantCellId } = task;

  const cell = await prisma.scenarioVariantCell.findUnique({
    where: { id: scenarioVariantCellId },
    include: { modelOutput: true },
  });
  if (!cell) {
    return;
  }

  // If cell is not pending, then some other job is already processing it
  if (cell.retrievalStatus !== "PENDING") {
    return;
  }

  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      retrievalStatus: "IN_PROGRESS",
    },
  });

  const variant = await prisma.promptVariant.findUnique({
    where: { id: cell.promptVariantId },
  });
  if (!variant) {
    return;
  }

  const scenario = await prisma.testScenario.findUnique({
    where: { id: cell.testScenarioId },
  });
  if (!scenario) {
    return;
  }

  const prompt = await constructPrompt(variant, scenario.variableValues);

  const streamingEnabled = shouldStream(prompt);
  let streamingChannel;
  if (streamingEnabled) {
    streamingChannel = generateChannel();
    // Save streaming channel so that UI can connect to it
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        streamingChannel,
      },
    });
  }

  const modelResponse = await getCompletionWithRetries(
    scenarioVariantCellId,
    prompt,
    streamingChannel,
  );

  let modelOutput = null;
  if (modelResponse.statusCode === 200) {
    const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
    modelOutput = await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId,
        inputHash,
        output: modelResponse.output,
        timeToComplete: modelResponse.timeToComplete,
        promptTokens: modelResponse.promptTokens,
        completionTokens: modelResponse.completionTokens,
      },
    });
  }

  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      statusCode: modelResponse.statusCode,
      errorMessage: modelResponse.errorMessage,
      streamingChannel: null,
      retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
      // Only link an output when one was actually created. On a failed
      // completion there is no id to connect, and a `connect` without a unique
      // identifier would make this update fail, leaving the cell IN_PROGRESS
      // instead of ERROR.
      ...(modelOutput ? { modelOutput: { connect: { id: modelOutput.id } } } : {}),
    },
  });

  await reevaluateVariant(cell.promptVariantId);
});

View File

@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";
import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";
// Every task definition this worker is able to execute.
const registeredTasks = [queryLLM];

// graphile-worker expects a map from task identifier to handler function.
const taskList: TaskList = {};
for (const { task } of registeredTasks) {
  taskList[task.identifier] = task.handler;
}
/**
 * Starts a graphile-worker runner for the registered tasks and waits until
 * the runner stops.
 */
async function main() {
  const runnerOptions = {
    connectionString: env.DATABASE_URL,
    concurrency: 20,
    // Keep graphile-worker's own signal handlers (SIGINT, SIGTERM, ...) so
    // the worker can shut down gracefully.
    noHandleSignals: false,
    pollInterval: 1000,
    // Tasks are registered explicitly through taskList; the alternative
    // `taskDirectory` option cannot be combined with it.
    taskList,
  };

  const runner = await run(runnerOptions);

  // Await the runner's promise right away: it settles when the worker exits
  // (through a fatal error or otherwise), and awaiting it here routes a
  // rejection to the catch handler below instead of leaving it unhandled.
  await runner.promise;
}

main().catch((err) => {
  console.error("Unhandled error occurred running worker: ", err);
  process.exit(1);
});

View File

@@ -1,4 +1,4 @@
import { type Evaluation } from "@prisma/client"; import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput"; import { evaluateOutput } from "./evaluateOutput";
@@ -12,21 +12,22 @@ export const reevaluateVariant = async (variantId: string) => {
where: { experimentId: variant.experimentId }, where: { experimentId: variant.experimentId },
}); });
const modelOutputs = await prisma.modelOutput.findMany({ const cells = await prisma.scenarioVariantCell.findMany({
where: { where: {
promptVariantId: variantId, promptVariantId: variantId,
statusCode: { notIn: [429] }, retrievalStatus: "COMPLETE",
testScenario: { visible: true }, testScenario: { visible: true },
modelOutput: { isNot: null },
}, },
include: { testScenario: true }, include: { testScenario: true, modelOutput: true },
}); });
await Promise.all( await Promise.all(
evaluations.map(async (evaluation) => { evaluations.map(async (evaluation) => {
const passCount = modelOutputs.filter((output) => const passCount = cells.filter((cell) =>
evaluateOutput(output, output.testScenario, evaluation), evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length; ).length;
const failCount = modelOutputs.length - passCount; const failCount = cells.length - passCount;
await prisma.evaluationResult.upsert({ await prisma.evaluationResult.upsert({
where: { where: {
@@ -55,22 +56,23 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
where: { experimentId: evaluation.experimentId, visible: true }, where: { experimentId: evaluation.experimentId, visible: true },
}); });
const modelOutputs = await prisma.modelOutput.findMany({ const cells = await prisma.scenarioVariantCell.findMany({
where: { where: {
promptVariantId: { in: variants.map((v) => v.id) }, promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true }, testScenario: { visible: true },
statusCode: { notIn: [429] }, statusCode: { notIn: [429] },
modelOutput: { isNot: null },
}, },
include: { testScenario: true }, include: { testScenario: true, modelOutput: true },
}); });
await Promise.all( await Promise.all(
variants.map(async (variant) => { variants.map(async (variant) => {
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id); const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
const passCount = outputs.filter((output) => const passCount = variantCells.filter((cell) =>
evaluateOutput(output, output.testScenario, evaluation), evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length; ).length;
const failCount = outputs.length - passCount; const failCount = variantCells.length - passCount;
await prisma.evaluationResult.upsert({ await prisma.evaluationResult.upsert({
where: { where: {

View File

@@ -0,0 +1,76 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
/**
 * Ensures a ScenarioVariantCell exists for the given variant/scenario pair.
 *
 * - Returns `null` when the variant or the scenario cannot be found.
 * - Returns the existing cell (with its modelOutput) when one already exists.
 * - Otherwise creates the cell. If any ModelOutput with the same input hash
 *   (sha256 of the constructed prompt) exists, its data is copied into a new
 *   ModelOutput for this cell; if not, an LLM retrieval is queued for the cell.
 *
 * @param variantId - Id of the prompt variant.
 * @param scenarioId - Id of the test scenario.
 * @returns The cell, with `modelOutput` set to the copied output when one was
 *   reused and `undefined` when a retrieval was queued instead.
 */
export const generateNewCell = async (variantId: string, scenarioId: string) => {
  const variant = await prisma.promptVariant.findUnique({ where: { id: variantId } });
  const scenario = await prisma.testScenario.findUnique({ where: { id: scenarioId } });

  if (!variant || !scenario) return null;

  const prompt = await constructPrompt(variant, scenario.variableValues);
  const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

  const existingCell = await prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: variantId,
        testScenarioId: scenarioId,
      },
    },
    include: { modelOutput: true },
  });
  if (existingCell) return existingCell;

  const createdCell = await prisma.scenarioVariantCell.create({
    data: {
      promptVariantId: variantId,
      testScenarioId: scenarioId,
    },
    include: { modelOutput: true },
  });

  // Look for any earlier output produced from an identical prompt.
  const cachedOutput = await prisma.modelOutput.findFirst({ where: { inputHash } });

  if (!cachedOutput) {
    // Nothing to reuse: ask for a fresh completion for this cell.
    const queuedCell = await queueLLMRetrievalTask(createdCell.id);
    return { ...queuedCell, modelOutput: undefined };
  }

  // Reuse the earlier result by copying it into an output row owned by this cell.
  const copiedOutput = await prisma.modelOutput.create({
    data: {
      scenarioVariantCellId: createdCell.id,
      inputHash,
      output: cachedOutput.output as Prisma.InputJsonValue,
      timeToComplete: cachedOutput.timeToComplete,
      promptTokens: cachedOutput.promptTokens,
      completionTokens: cachedOutput.completionTokens,
      createdAt: cachedOutput.createdAt,
      updatedAt: cachedOutput.updatedAt,
    },
  });

  return { ...createdCell, modelOutput: copiedOutput };
};

View File

@@ -9,7 +9,7 @@ import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens"; import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings"; import { rateLimitErrorMessage } from "~/sharedStrings";
type CompletionResponse = { export type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull; output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
statusCode: number; statusCode: number;
errorMessage: string | null; errorMessage: string | null;

View File

@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";
/**
 * Resets a cell to PENDING (clearing its error message) and starts the
 * `queryLLM` task for it.
 *
 * The task handler is invoked directly and is not awaited, so this function
 * returns as soon as the cell has been updated. Failures of the background
 * task are logged rather than left as unhandled promise rejections.
 *
 * @param cellId - Id of the ScenarioVariantCell to (re)process.
 * @returns The updated cell, including its modelOutput relation.
 */
export const queueLLMRetrievalTask = async (cellId: string) => {
  const updatedCell = await prisma.scenarioVariantCell.update({
    where: {
      id: cellId,
    },
    data: {
      retrievalStatus: "PENDING",
      errorMessage: null,
    },
    include: {
      modelOutput: true,
    },
  });

  // Only a logger is supplied in place of the full graphile-worker helpers.
  // @ts-expect-error we aren't passing the helpers but that's ok
  const backgroundRun = queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });

  // Fire-and-forget, but always attach a rejection handler: an unhandled
  // rejection would otherwise be reported at process level with no context.
  void Promise.resolve(backgroundRun).catch((error: unknown) => {
    console.error(`queryLLM task failed for cell ${cellId}: `, error);
  });

  return updatedCell;
};

View File

@@ -0,0 +1,7 @@
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";
/**
 * Tells whether a prompt config asks for a streamed response.
 *
 * @param config - The prompt config to inspect.
 * @returns `true` only when `config` is an object that has a `stream`
 *   property strictly equal to `true`.
 */
export const shouldStream = (config: JSONSerializable): boolean => {
  if (!isObject(config)) return false;
  if (!("stream" in config)) return false;
  return config.stream === true;
};

View File

@@ -0,0 +1 @@
/**
 * Returns a promise that resolves (with no value) after `ms` milliseconds.
 *
 * @param ms - Time to wait, in milliseconds.
 */
export const sleep = (ms: number) =>
  new Promise((resolve) => {
    setTimeout(resolve, ms);
  });

View File

@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
const url = env.NEXT_PUBLIC_SOCKET_URL; const url = env.NEXT_PUBLIC_SOCKET_URL;
export default function useSocket(channel?: string) { export default function useSocket(channel?: string | null) {
const socketRef = useRef<Socket>(); const socketRef = useRef<Socket>();
const [message, setMessage] = useState<ChatCompletion | null>(null); const [message, setMessage] = useState<ChatCompletion | null>(null);