Compare commits
9 Commits
node-versi
...
save-butto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74c201d3a8 | ||
|
|
ab9c721d09 | ||
|
|
0a2578a1d8 | ||
|
|
1bebaff386 | ||
|
|
3bf5eaf4a2 | ||
|
|
ded97f8bb9 | ||
|
|
26ee8698be | ||
|
|
b98eb9b729 | ||
|
|
032c07ec65 |
@@ -11,6 +11,7 @@
|
|||||||
"build": "next build",
|
"build": "next build",
|
||||||
"dev:next": "next dev",
|
"dev:next": "next dev",
|
||||||
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
|
||||||
|
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
|
||||||
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
|
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
|
||||||
"postinstall": "prisma generate",
|
"postinstall": "prisma generate",
|
||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
@@ -44,6 +45,7 @@
|
|||||||
"express": "^4.18.2",
|
"express": "^4.18.2",
|
||||||
"framer-motion": "^10.12.17",
|
"framer-motion": "^10.12.17",
|
||||||
"gpt-tokens": "^1.0.10",
|
"gpt-tokens": "^1.0.10",
|
||||||
|
"graphile-worker": "^0.13.0",
|
||||||
"immer": "^10.0.2",
|
"immer": "^10.0.2",
|
||||||
"isolated-vm": "^4.5.0",
|
"isolated-vm": "^4.5.0",
|
||||||
"json-stringify-pretty-compact": "^4.0.0",
|
"json-stringify-pretty-compact": "^4.0.0",
|
||||||
|
|||||||
506
pnpm-lock.yaml
generated
506
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,49 @@
|
|||||||
|
-- Drop the foreign key constraints on the original ModelOutput
|
||||||
|
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
|
||||||
|
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
|
||||||
|
|
||||||
|
-- Rename the old table
|
||||||
|
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
|
||||||
|
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
|
||||||
|
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
|
||||||
|
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
|
||||||
|
|
||||||
|
-- Add the new fields to the renamed table
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
|
||||||
|
ALTER COLUMN "statusCode" DROP NOT NULL,
|
||||||
|
ALTER COLUMN "timeToComplete" DROP NOT NULL;
|
||||||
|
|
||||||
|
-- Create the new table
|
||||||
|
CREATE TABLE "ModelOutput" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"inputHash" TEXT NOT NULL,
|
||||||
|
"output" JSONB NOT NULL,
|
||||||
|
"timeToComplete" INTEGER NOT NULL DEFAULT 0,
|
||||||
|
"promptTokens" INTEGER,
|
||||||
|
"completionTokens" INTEGER,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"scenarioVariantCellId" UUID
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Move inputHash index
|
||||||
|
DROP INDEX "ScenarioVariantCell_inputHash_idx";
|
||||||
|
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
|
||||||
|
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
|
||||||
|
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
|
||||||
|
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- CreateEnum
|
||||||
|
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
|
||||||
|
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';
|
||||||
@@ -41,7 +41,7 @@ model PromptVariant {
|
|||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
ModelOutput ModelOutput[]
|
scenarioVariantCells ScenarioVariantCell[]
|
||||||
EvaluationResult EvaluationResult[]
|
EvaluationResult EvaluationResult[]
|
||||||
|
|
||||||
@@index([uiId])
|
@@index([uiId])
|
||||||
@@ -61,7 +61,7 @@ model TestScenario {
|
|||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
ModelOutput ModelOutput[]
|
scenarioVariantCells ScenarioVariantCell[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model TemplateVariable {
|
model TemplateVariable {
|
||||||
@@ -76,17 +76,28 @@ model TemplateVariable {
|
|||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
model ModelOutput {
|
enum CellRetrievalStatus {
|
||||||
|
PENDING
|
||||||
|
IN_PROGRESS
|
||||||
|
COMPLETE
|
||||||
|
ERROR
|
||||||
|
}
|
||||||
|
|
||||||
|
model ScenarioVariantCell {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
inputHash String
|
inputHash String? // TODO: Remove once migration is complete
|
||||||
output Json
|
output Json? // TODO: Remove once migration is complete
|
||||||
statusCode Int
|
statusCode Int?
|
||||||
errorMessage String?
|
errorMessage String?
|
||||||
timeToComplete Int @default(0)
|
timeToComplete Int? @default(0) // TODO: Remove once migration is complete
|
||||||
|
retryTime DateTime?
|
||||||
|
streamingChannel String?
|
||||||
|
retrievalStatus CellRetrievalStatus @default(COMPLETE)
|
||||||
|
|
||||||
promptTokens Int? // Added promptTokens field
|
promptTokens Int? // TODO: Remove once migration is complete
|
||||||
completionTokens Int? // Added completionTokens field
|
completionTokens Int? // TODO: Remove once migration is complete
|
||||||
|
modelOutput ModelOutput?
|
||||||
|
|
||||||
promptVariantId String @db.Uuid
|
promptVariantId String @db.Uuid
|
||||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||||
@@ -98,6 +109,24 @@ model ModelOutput {
|
|||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
@@unique([promptVariantId, testScenarioId])
|
@@unique([promptVariantId, testScenarioId])
|
||||||
|
}
|
||||||
|
|
||||||
|
model ModelOutput {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
inputHash String
|
||||||
|
output Json
|
||||||
|
timeToComplete Int @default(0)
|
||||||
|
promptTokens Int?
|
||||||
|
completionTokens Int?
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
scenarioVariantCellId String @db.Uuid
|
||||||
|
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
@@unique([scenarioVariantCellId])
|
||||||
@@index([inputHash])
|
@@index([inputHash])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const experiment = await prisma.experiment.create({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.modelOutput.deleteMany({
|
await prisma.scenarioVariantCell.deleteMany({
|
||||||
where: {
|
where: {
|
||||||
promptVariant: {
|
promptVariant: {
|
||||||
experimentId,
|
experimentId,
|
||||||
|
|||||||
33
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
33
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import { Button, HStack, Icon } from "@chakra-ui/react";
|
||||||
|
import { BsArrowClockwise } from "react-icons/bs";
|
||||||
|
|
||||||
|
export const CellOptions = ({
|
||||||
|
refetchingOutput,
|
||||||
|
refetchOutput,
|
||||||
|
}: {
|
||||||
|
refetchingOutput: boolean;
|
||||||
|
refetchOutput: () => void;
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<HStack justifyContent="flex-end" w="full">
|
||||||
|
{!refetchingOutput && (
|
||||||
|
<Button
|
||||||
|
size="xs"
|
||||||
|
w={4}
|
||||||
|
h={4}
|
||||||
|
py={4}
|
||||||
|
px={4}
|
||||||
|
minW={0}
|
||||||
|
borderRadius={8}
|
||||||
|
color="gray.500"
|
||||||
|
variant="ghost"
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={refetchOutput}
|
||||||
|
aria-label="refetch output"
|
||||||
|
>
|
||||||
|
<Icon as={BsArrowClockwise} boxSize={4} />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -1,29 +1,21 @@
|
|||||||
import { type ModelOutput } from "@prisma/client";
|
import { type ScenarioVariantCell } from "@prisma/client";
|
||||||
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
|
import { VStack, Text } from "@chakra-ui/react";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { BsArrowClockwise } from "react-icons/bs";
|
|
||||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
|
||||||
import pluralize from "pluralize";
|
import pluralize from "pluralize";
|
||||||
|
|
||||||
const MAX_AUTO_RETRIES = 3;
|
|
||||||
|
|
||||||
export const ErrorHandler = ({
|
export const ErrorHandler = ({
|
||||||
output,
|
cell,
|
||||||
refetchOutput,
|
refetchOutput,
|
||||||
numPreviousTries,
|
|
||||||
}: {
|
}: {
|
||||||
output: ModelOutput;
|
cell: ScenarioVariantCell;
|
||||||
refetchOutput: () => void;
|
refetchOutput: () => void;
|
||||||
numPreviousTries: number;
|
|
||||||
}) => {
|
}) => {
|
||||||
const [msToWait, setMsToWait] = useState(0);
|
const [msToWait, setMsToWait] = useState(0);
|
||||||
const shouldAutoRetry =
|
|
||||||
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!shouldAutoRetry) return;
|
if (!cell.retryTime) return;
|
||||||
|
|
||||||
const initialWaitTime = calculateDelay(numPreviousTries);
|
const initialWaitTime = cell.retryTime.getTime() - Date.now();
|
||||||
const msModuloOneSecond = initialWaitTime % 1000;
|
const msModuloOneSecond = initialWaitTime % 1000;
|
||||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||||
setMsToWait(remainingTime);
|
setMsToWait(remainingTime);
|
||||||
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
|
|||||||
setMsToWait(remainingTime);
|
setMsToWait(remainingTime);
|
||||||
|
|
||||||
if (remainingTime <= 0) {
|
if (remainingTime <= 0) {
|
||||||
refetchOutput();
|
|
||||||
clearInterval(interval);
|
clearInterval(interval);
|
||||||
}
|
}
|
||||||
}, 1000);
|
}, 1000);
|
||||||
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
|
|||||||
clearInterval(interval);
|
clearInterval(interval);
|
||||||
clearTimeout(timeout);
|
clearTimeout(timeout);
|
||||||
};
|
};
|
||||||
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
|
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack w="full">
|
<VStack w="full">
|
||||||
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
|
|
||||||
<Text color="red.600" fontWeight="bold">
|
|
||||||
Error
|
|
||||||
</Text>
|
|
||||||
<Button
|
|
||||||
size="xs"
|
|
||||||
w={4}
|
|
||||||
h={4}
|
|
||||||
px={4}
|
|
||||||
py={4}
|
|
||||||
minW={0}
|
|
||||||
borderRadius={8}
|
|
||||||
variant="ghost"
|
|
||||||
cursor="pointer"
|
|
||||||
onClick={refetchOutput}
|
|
||||||
aria-label="refetch output"
|
|
||||||
>
|
|
||||||
<Icon as={BsArrowClockwise} boxSize={6} />
|
|
||||||
</Button>
|
|
||||||
</HStack>
|
|
||||||
<Text color="red.600" wordBreak="break-word">
|
<Text color="red.600" wordBreak="break-word">
|
||||||
{output.errorMessage}
|
{cell.errorMessage}
|
||||||
</Text>
|
</Text>
|
||||||
{msToWait > 0 && (
|
{msToWait > 0 && (
|
||||||
<Text color="red.600" fontSize="sm">
|
<Text color="red.600" fontSize="sm">
|
||||||
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
|
|||||||
</VStack>
|
</VStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const MIN_DELAY = 500; // milliseconds
|
|
||||||
const MAX_DELAY = 5000; // milliseconds
|
|
||||||
|
|
||||||
function calculateDelay(numPreviousTries: number): number {
|
|
||||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
|
||||||
const jitter = Math.random() * baseDelay;
|
|
||||||
return baseDelay + jitter;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
import { type RouterOutputs, api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { type PromptVariant, type Scenario } from "../types";
|
import { type PromptVariant, type Scenario } from "../types";
|
||||||
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
|
import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||||
import stringify from "json-stringify-pretty-compact";
|
import stringify from "json-stringify-pretty-compact";
|
||||||
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
|
import { type ReactElement, useState, useEffect } from "react";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
import { generateChannel } from "~/utils/generateChannel";
|
|
||||||
import { isObject } from "lodash";
|
|
||||||
import useSocket from "~/utils/useSocket";
|
import useSocket from "~/utils/useSocket";
|
||||||
import { OutputStats } from "./OutputStats";
|
import { OutputStats } from "./OutputStats";
|
||||||
import { ErrorHandler } from "./ErrorHandler";
|
import { ErrorHandler } from "./ErrorHandler";
|
||||||
|
import { CellOptions } from "./CellOptions";
|
||||||
|
|
||||||
export default function OutputCell({
|
export default function OutputCell({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -37,116 +36,103 @@ export default function OutputCell({
|
|||||||
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
||||||
// disabledReason = "Save your prompt variant to see output";
|
// disabledReason = "Save your prompt variant to see output";
|
||||||
|
|
||||||
const outputMutation = api.outputs.get.useMutation();
|
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||||
|
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
||||||
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
|
{ scenarioId: scenario.id, variantId: variant.id },
|
||||||
const [channel, setChannel] = useState<string | undefined>(undefined);
|
{ refetchInterval },
|
||||||
const [numPreviousTries, setNumPreviousTries] = useState(0);
|
|
||||||
|
|
||||||
const fetchMutex = useRef(false);
|
|
||||||
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
|
|
||||||
async (forceRefetch?: boolean) => {
|
|
||||||
if (fetchMutex.current) return;
|
|
||||||
setNumPreviousTries((prev) => prev + 1);
|
|
||||||
|
|
||||||
fetchMutex.current = true;
|
|
||||||
setOutput(null);
|
|
||||||
|
|
||||||
const shouldStream =
|
|
||||||
isObject(variant) &&
|
|
||||||
"config" in variant &&
|
|
||||||
isObject(variant.config) &&
|
|
||||||
"stream" in variant.config &&
|
|
||||||
variant.config.stream === true;
|
|
||||||
|
|
||||||
const channel = shouldStream ? generateChannel() : undefined;
|
|
||||||
setChannel(channel);
|
|
||||||
|
|
||||||
const output = await outputMutation.mutateAsync({
|
|
||||||
scenarioId: scenario.id,
|
|
||||||
variantId: variant.id,
|
|
||||||
channel,
|
|
||||||
forceRefetch,
|
|
||||||
});
|
|
||||||
setOutput(output);
|
|
||||||
await utils.promptVariants.stats.invalidate();
|
|
||||||
fetchMutex.current = false;
|
|
||||||
},
|
|
||||||
[outputMutation, scenario.id, variant.id],
|
|
||||||
);
|
);
|
||||||
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
|
|
||||||
|
|
||||||
useEffect(fetchOutput, [scenario.id, variant.id]);
|
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
|
||||||
|
api.scenarioVariantCells.forceRefetch.useMutation();
|
||||||
|
const [hardRefetch] = useHandledAsyncCallback(async () => {
|
||||||
|
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
||||||
|
await utils.scenarioVariantCells.get.invalidate({
|
||||||
|
scenarioId: scenario.id,
|
||||||
|
variantId: variant.id,
|
||||||
|
});
|
||||||
|
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||||
|
|
||||||
|
const fetchingOutput = queryLoading || refetchingOutput;
|
||||||
|
|
||||||
|
const awaitingOutput =
|
||||||
|
!cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
|
||||||
|
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||||
|
|
||||||
|
const modelOutput = cell?.modelOutput;
|
||||||
|
|
||||||
// Disconnect from socket if we're not streaming anymore
|
// Disconnect from socket if we're not streaming anymore
|
||||||
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
|
const streamedMessage = useSocket(cell?.streamingChannel);
|
||||||
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
||||||
|
|
||||||
if (!vars) return null;
|
if (!vars) return null;
|
||||||
|
|
||||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||||
|
|
||||||
if (fetchingOutput && !streamedMessage)
|
if (awaitingOutput && !streamedMessage)
|
||||||
return (
|
return (
|
||||||
<Center h="100%" w="100%">
|
<Center h="100%" w="100%">
|
||||||
<Spinner />
|
<Spinner />
|
||||||
</Center>
|
</Center>
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||||
|
|
||||||
if (output && output.errorMessage) {
|
if (cell && cell.errorMessage) {
|
||||||
return (
|
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
||||||
<ErrorHandler
|
|
||||||
output={output}
|
|
||||||
refetchOutput={hardRefetch}
|
|
||||||
numPreviousTries={numPreviousTries}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = output?.output as unknown as ChatCompletion;
|
const response = modelOutput?.output as unknown as ChatCompletion;
|
||||||
const message = response?.choices?.[0]?.message;
|
const message = response?.choices?.[0]?.message;
|
||||||
|
|
||||||
if (output && message?.function_call) {
|
if (modelOutput && message?.function_call) {
|
||||||
const rawArgs = message.function_call.arguments ?? "null";
|
const rawArgs = message.function_call.arguments ?? "null";
|
||||||
let parsedArgs: string;
|
let parsedArgs: string;
|
||||||
try {
|
try {
|
||||||
parsedArgs = JSON.parse(rawArgs);
|
parsedArgs = JSON.parse(rawArgs);
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
|
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
|
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
|
||||||
<SyntaxHighlighter
|
<VStack w="full" spacing={0}>
|
||||||
customStyle={{ overflowX: "unset" }}
|
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||||
language="json"
|
<SyntaxHighlighter
|
||||||
style={docco}
|
customStyle={{ overflowX: "unset" }}
|
||||||
lineProps={{
|
language="json"
|
||||||
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
|
style={docco}
|
||||||
}}
|
lineProps={{
|
||||||
wrapLines
|
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
|
||||||
>
|
}}
|
||||||
{stringify(
|
wrapLines
|
||||||
{
|
>
|
||||||
function: message.function_call.name,
|
{stringify(
|
||||||
args: parsedArgs,
|
{
|
||||||
},
|
function: message.function_call.name,
|
||||||
{ maxLength: 40 },
|
args: parsedArgs,
|
||||||
)}
|
},
|
||||||
</SyntaxHighlighter>
|
{ maxLength: 40 },
|
||||||
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
|
)}
|
||||||
|
</SyntaxHighlighter>
|
||||||
|
</VStack>
|
||||||
|
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
|
const contentToDisplay =
|
||||||
|
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||||
{contentToDisplay}
|
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||||
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
|
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||||
|
<Text>{contentToDisplay}</Text>
|
||||||
|
</VStack>
|
||||||
|
{modelOutput && (
|
||||||
|
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import { HStack, Icon, Text } from "@chakra-ui/react";
|
|||||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||||
|
|
||||||
const SHOW_COST = false;
|
const SHOW_COST = true;
|
||||||
const SHOW_TIME = false;
|
const SHOW_TIME = true;
|
||||||
|
|
||||||
export const OutputStats = ({
|
export const OutputStats = ({
|
||||||
model,
|
model,
|
||||||
@@ -35,8 +35,6 @@ export const OutputStats = ({
|
|||||||
|
|
||||||
const cost = promptCost + completionCost;
|
const cost = promptCost + completionCost;
|
||||||
|
|
||||||
if (!evals.length) return null;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||||
<HStack flex={1}>
|
<HStack flex={1}>
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
|
|||||||
|
|
||||||
export default function ScenarioEditor({
|
export default function ScenarioEditor({
|
||||||
scenario,
|
scenario,
|
||||||
hovered,
|
...props
|
||||||
}: {
|
}: {
|
||||||
scenario: Scenario;
|
scenario: Scenario;
|
||||||
hovered: boolean;
|
hovered: boolean;
|
||||||
|
canHide: boolean;
|
||||||
}) {
|
}) {
|
||||||
const savedValues = scenario.variableValues as Record<string, string>;
|
const savedValues = scenario.variableValues as Record<string, string>;
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -92,30 +93,34 @@ export default function ScenarioEditor({
|
|||||||
onDrop={onReorder}
|
onDrop={onReorder}
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
>
|
>
|
||||||
<Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
|
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
|
||||||
<Tooltip label="Hide scenario" hasArrow>
|
{props.canHide && (
|
||||||
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
<>
|
||||||
<Button
|
<Tooltip label="Hide scenario" hasArrow>
|
||||||
variant="unstyled"
|
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
|
||||||
color="gray.400"
|
<Button
|
||||||
height="unset"
|
variant="unstyled"
|
||||||
width="unset"
|
color="gray.400"
|
||||||
minW="unset"
|
height="unset"
|
||||||
onClick={onHide}
|
width="unset"
|
||||||
_hover={{
|
minW="unset"
|
||||||
color: "gray.800",
|
onClick={onHide}
|
||||||
cursor: "pointer",
|
_hover={{
|
||||||
}}
|
color: "gray.800",
|
||||||
>
|
cursor: "pointer",
|
||||||
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
|
}}
|
||||||
</Button>
|
>
|
||||||
</Tooltip>
|
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
|
||||||
<Icon
|
</Button>
|
||||||
as={RiDraggable}
|
</Tooltip>
|
||||||
boxSize={6}
|
<Icon
|
||||||
color="gray.400"
|
as={RiDraggable}
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
boxSize={6}
|
||||||
/>
|
color="gray.400"
|
||||||
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Stack>
|
</Stack>
|
||||||
{variableLabels.length === 0 ? (
|
{variableLabels.length === 0 ? (
|
||||||
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
|
||||||
|
|||||||
@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
|
|||||||
import ScenarioEditor from "./ScenarioEditor";
|
import ScenarioEditor from "./ScenarioEditor";
|
||||||
import type { PromptVariant, Scenario } from "./types";
|
import type { PromptVariant, Scenario } from "./types";
|
||||||
|
|
||||||
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
|
const ScenarioRow = (props: {
|
||||||
|
scenario: Scenario;
|
||||||
|
variants: PromptVariant[];
|
||||||
|
canHide: boolean;
|
||||||
|
}) => {
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const highlightStyle = { backgroundColor: "gray.50" };
|
const highlightStyle = { backgroundColor: "gray.50" };
|
||||||
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
|
|||||||
sx={isHovered ? highlightStyle : undefined}
|
sx={isHovered ? highlightStyle : undefined}
|
||||||
borderLeftWidth={1}
|
borderLeftWidth={1}
|
||||||
>
|
>
|
||||||
<ScenarioEditor scenario={props.scenario} hovered={isHovered} />
|
<ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
{props.variants.map((variant) => (
|
{props.variants.map((variant) => (
|
||||||
<GridItem
|
<GridItem
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import { type PromptVariant } from "./types";
|
|||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useAppStore } from "~/state/store";
|
import { useAppStore } from "~/state/store";
|
||||||
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
import { editorBackground } from "~/state/sharedVariantEditor.slice";
|
||||||
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|
||||||
|
export default function VariantEditor(props: { variant: PromptVariant }) {
|
||||||
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
const monaco = useAppStore.use.sharedVariantEditor.monaco();
|
||||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
@@ -17,10 +18,12 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|||||||
|
|
||||||
const checkForChanges = useCallback(() => {
|
const checkForChanges = useCallback(() => {
|
||||||
if (!editorRef.current) return;
|
if (!editorRef.current) return;
|
||||||
const currentConfig = editorRef.current.getValue();
|
const currentFn = editorRef.current.getValue();
|
||||||
setIsChanged(currentConfig !== lastSavedFn);
|
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
|
||||||
}, [lastSavedFn]);
|
}, [lastSavedFn]);
|
||||||
|
|
||||||
|
useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
|
||||||
|
|
||||||
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
@@ -75,9 +78,9 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
await utils.promptVariants.list.invalidate();
|
setIsChanged(false);
|
||||||
|
|
||||||
checkForChanges();
|
await utils.promptVariants.list.invalidate();
|
||||||
}, [checkForChanges]);
|
}, [checkForChanges]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import { RiDraggable } from "react-icons/ri";
|
|||||||
import { cellPadding, headerMinHeight } from "../constants";
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant }) {
|
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const [isDragTarget, setIsDragTarget] = useState(false);
|
const [isDragTarget, setIsDragTarget] = useState(false);
|
||||||
const [isInputHovered, setIsInputHovered] = useState(false);
|
const [isInputHovered, setIsInputHovered] = useState(false);
|
||||||
@@ -95,11 +95,13 @@ export default function VariantHeader(props: { variant: PromptVariant }) {
|
|||||||
onMouseEnter={() => setIsInputHovered(true)}
|
onMouseEnter={() => setIsInputHovered(true)}
|
||||||
onMouseLeave={() => setIsInputHovered(false)}
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
/>
|
/>
|
||||||
<Tooltip label="Hide Variant" hasArrow>
|
{props.canHide && (
|
||||||
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
<Tooltip label="Remove Variant" hasArrow>
|
||||||
<Icon as={BsX} boxSize={6} />
|
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
||||||
</Button>
|
<Icon as={BsX} boxSize={6} />
|
||||||
</Tooltip>
|
</Button>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { api } from "~/utils/api";
|
|||||||
import NewScenarioButton from "./NewScenarioButton";
|
import NewScenarioButton from "./NewScenarioButton";
|
||||||
import NewVariantButton from "./NewVariantButton";
|
import NewVariantButton from "./NewVariantButton";
|
||||||
import ScenarioRow from "./ScenarioRow";
|
import ScenarioRow from "./ScenarioRow";
|
||||||
import VariantConfigEditor from "./VariantEditor";
|
import VariantEditor from "./VariantEditor";
|
||||||
import VariantHeader from "./VariantHeader";
|
import VariantHeader from "./VariantHeader";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { BsPencil } from "react-icons/bs";
|
import { BsPencil } from "react-icons/bs";
|
||||||
@@ -78,7 +78,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
||||||
<VariantHeader variant={variant} />
|
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
<GridItem
|
<GridItem
|
||||||
@@ -94,7 +94,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId}>
|
<GridItem key={variant.uiId}>
|
||||||
<VariantConfigEditor variant={variant} />
|
<VariantEditor variant={variant} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
@@ -103,7 +103,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
{scenarios.data.map((scenario) => (
|
{scenarios.data.map((scenario) => (
|
||||||
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
|
<ScenarioRow
|
||||||
|
key={scenario.uiId}
|
||||||
|
scenario={scenario}
|
||||||
|
variants={variants.data}
|
||||||
|
canHide={scenarios.data.length > 1}
|
||||||
|
/>
|
||||||
))}
|
))}
|
||||||
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
<GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
|
||||||
<NewScenarioButton />
|
<NewScenarioButton />
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { type GetServerSideProps } from "next";
|
import { type GetServerSideProps } from "next";
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/require-await
|
// eslint-disable-next-line @typescript-eslint/require-await
|
||||||
export const getServerSideProps: GetServerSideProps = async (context) => {
|
export const getServerSideProps: GetServerSideProps = async () => {
|
||||||
return {
|
return {
|
||||||
redirect: {
|
redirect: {
|
||||||
destination: "/experiments",
|
destination: "/experiments",
|
||||||
|
|||||||
@@ -3,10 +3,6 @@ import { prisma } from "../db";
|
|||||||
import { openai } from "../utils/openai";
|
import { openai } from "../utils/openai";
|
||||||
import { pick } from "lodash";
|
import { pick } from "lodash";
|
||||||
|
|
||||||
function promptHasVariable(prompt: string, variableName: string) {
|
|
||||||
return prompt.includes(`{{${variableName}}}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
type AxiosError = {
|
type AxiosError = {
|
||||||
response?: {
|
response?: {
|
||||||
data?: {
|
data?: {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
|
|||||||
import { createTRPCRouter } from "~/server/api/trpc";
|
import { createTRPCRouter } from "~/server/api/trpc";
|
||||||
import { experimentsRouter } from "./routers/experiments.router";
|
import { experimentsRouter } from "./routers/experiments.router";
|
||||||
import { scenariosRouter } from "./routers/scenarios.router";
|
import { scenariosRouter } from "./routers/scenarios.router";
|
||||||
import { modelOutputsRouter } from "./routers/modelOutputs.router";
|
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
|
||||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
|
|||||||
promptVariants: promptVariantsRouter,
|
promptVariants: promptVariantsRouter,
|
||||||
experiments: experimentsRouter,
|
experiments: experimentsRouter,
|
||||||
scenarios: scenariosRouter,
|
scenarios: scenariosRouter,
|
||||||
outputs: modelOutputsRouter,
|
scenarioVariantCells: scenarioVariantCellsRouter,
|
||||||
templateVars: templateVarsRouter,
|
templateVars: templateVarsRouter,
|
||||||
evaluations: evaluationsRouter,
|
evaluations: evaluationsRouter,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { z } from "zod";
|
|||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|
||||||
export const experimentsRouter = createTRPCRouter({
|
export const experimentsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.query(async () => {
|
list: publicProcedure.query(async () => {
|
||||||
@@ -64,7 +65,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.$transaction([
|
const [variant, _, scenario] = await prisma.$transaction([
|
||||||
prisma.promptVariant.create({
|
prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
@@ -73,19 +74,29 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
constructFn: dedent`prompt = {
|
constructFn: dedent`prompt = {
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
stream: true,
|
stream: true,
|
||||||
messages: [{ role: "system", content: "Return 'Ready to go!'" }],
|
messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }],
|
||||||
}`,
|
}`,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
prisma.templateVariable.create({
|
||||||
|
data: {
|
||||||
|
experimentId: exp.id,
|
||||||
|
label: "text",
|
||||||
|
},
|
||||||
|
}),
|
||||||
prisma.testScenario.create({
|
prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
variableValues: {},
|
variableValues: {
|
||||||
|
text: "This is a test scenario.",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
await generateNewCell(variant.id, scenario.id);
|
||||||
|
|
||||||
return exp;
|
return exp;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
|
||||||
import { prisma } from "~/server/db";
|
|
||||||
import crypto from "crypto";
|
|
||||||
import type { Prisma } from "@prisma/client";
|
|
||||||
import { reevaluateVariant } from "~/server/utils/evaluations";
|
|
||||||
import { getCompletion } from "~/server/utils/getCompletion";
|
|
||||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
|
||||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
|
||||||
|
|
||||||
export const modelOutputsRouter = createTRPCRouter({
|
|
||||||
get: publicProcedure
|
|
||||||
.input(
|
|
||||||
z.object({
|
|
||||||
scenarioId: z.string(),
|
|
||||||
variantId: z.string(),
|
|
||||||
channel: z.string().optional(),
|
|
||||||
forceRefetch: z.boolean().optional(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.mutation(async ({ input }) => {
|
|
||||||
const existing = await prisma.modelOutput.findUnique({
|
|
||||||
where: {
|
|
||||||
promptVariantId_testScenarioId: {
|
|
||||||
promptVariantId: input.variantId,
|
|
||||||
testScenarioId: input.scenarioId,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (existing && !input.forceRefetch) return existing;
|
|
||||||
|
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
|
||||||
where: {
|
|
||||||
id: input.variantId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const scenario = await prisma.testScenario.findUnique({
|
|
||||||
where: {
|
|
||||||
id: input.scenarioId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!variant || !scenario) return null;
|
|
||||||
|
|
||||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
|
||||||
|
|
||||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
|
||||||
|
|
||||||
// TODO: we should probably only use this if temperature=0
|
|
||||||
const existingResponse = await prisma.modelOutput.findFirst({
|
|
||||||
where: { inputHash, errorMessage: null },
|
|
||||||
});
|
|
||||||
|
|
||||||
let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
|
|
||||||
|
|
||||||
if (existingResponse) {
|
|
||||||
modelResponse = {
|
|
||||||
output: existingResponse.output as Prisma.InputJsonValue,
|
|
||||||
statusCode: existingResponse.statusCode,
|
|
||||||
errorMessage: existingResponse.errorMessage,
|
|
||||||
timeToComplete: existingResponse.timeToComplete,
|
|
||||||
promptTokens: existingResponse.promptTokens ?? undefined,
|
|
||||||
completionTokens: existingResponse.completionTokens ?? undefined,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
try {
|
|
||||||
modelResponse = await getCompletion(
|
|
||||||
prompt as unknown as CompletionCreateParams,
|
|
||||||
input.channel,
|
|
||||||
);
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelOutput = await prisma.modelOutput.upsert({
|
|
||||||
where: {
|
|
||||||
promptVariantId_testScenarioId: {
|
|
||||||
promptVariantId: input.variantId,
|
|
||||||
testScenarioId: input.scenarioId,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
create: {
|
|
||||||
promptVariantId: input.variantId,
|
|
||||||
testScenarioId: input.scenarioId,
|
|
||||||
inputHash,
|
|
||||||
...modelResponse,
|
|
||||||
},
|
|
||||||
update: {
|
|
||||||
...modelResponse,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
await reevaluateVariant(input.variantId);
|
|
||||||
|
|
||||||
return modelOutput;
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import dedent from "dedent";
|
||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
import { OpenAIChatModel } from "~/server/types";
|
import { OpenAIChatModel } from "~/server/types";
|
||||||
import { constructPrompt } from "~/server/utils/constructPrompt";
|
import { constructPrompt } from "~/server/utils/constructPrompt";
|
||||||
import userError from "~/server/utils/error";
|
import userError from "~/server/utils/error";
|
||||||
@@ -43,17 +45,24 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
visible: true,
|
visible: true,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
const outputCount = await prisma.modelOutput.count({
|
const outputCount = await prisma.scenarioVariantCell.count({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: input.variantId,
|
promptVariantId: input.variantId,
|
||||||
testScenario: { visible: true },
|
testScenario: { visible: true },
|
||||||
|
modelOutput: {
|
||||||
|
isNot: null,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const overallTokens = await prisma.modelOutput.aggregate({
|
const overallTokens = await prisma.modelOutput.aggregate({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: input.variantId,
|
scenarioVariantCell: {
|
||||||
testScenario: { visible: true },
|
promptVariantId: input.variantId,
|
||||||
|
testScenario: {
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
_sum: {
|
_sum: {
|
||||||
promptTokens: true,
|
promptTokens: true,
|
||||||
@@ -105,7 +114,18 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
label: `Prompt Variant ${largestSortIndex + 2}`,
|
label: `Prompt Variant ${largestSortIndex + 2}`,
|
||||||
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
||||||
constructFn: lastVariant?.constructFn ?? "",
|
constructFn:
|
||||||
|
lastVariant?.constructFn ??
|
||||||
|
dedent`
|
||||||
|
prompt = {
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "system",
|
||||||
|
content: "Return 'Hello, world!'",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
model: lastVariant?.model ?? "gpt-3.5-turbo",
|
model: lastVariant?.model ?? "gpt-3.5-turbo",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@@ -115,6 +135,17 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
recordExperimentUpdated(input.experimentId),
|
recordExperimentUpdated(input.experimentId),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const scenario of scenarios) {
|
||||||
|
await generateNewCell(newVariant.id, scenario.id);
|
||||||
|
}
|
||||||
|
|
||||||
return newVariant;
|
return newVariant;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -234,6 +265,17 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
|
|
||||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: newVariant.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const scenario of scenarios) {
|
||||||
|
await generateNewCell(newVariant.id, scenario.id);
|
||||||
|
}
|
||||||
|
|
||||||
return { status: "ok" } as const;
|
return { status: "ok" } as const;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
|||||||
68
src/server/api/routers/scenarioVariantCells.router.ts
Normal file
68
src/server/api/routers/scenarioVariantCells.router.ts
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
|
||||||
|
|
||||||
|
// Both procedures address a cell by its (scenario, variant) pair.
const cellKeyInput = z.object({
  scenarioId: z.string(),
  variantId: z.string(),
});

/** Looks up the cell for a (variant, scenario) pair together with its model output, if any. */
const findCellWithOutput = (key: z.infer<typeof cellKeyInput>) =>
  prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: key.variantId,
        testScenarioId: key.scenarioId,
      },
    },
    include: {
      modelOutput: true,
    },
  });

/**
 * tRPC router for individual scenario/variant cells.
 *
 * - `get` returns the cell (with its model output) or null when none exists.
 * - `forceRefetch` discards any stored output and queues a fresh retrieval;
 *   when the cell does not exist yet it is generated instead.
 */
export const scenarioVariantCellsRouter = createTRPCRouter({
  get: publicProcedure.input(cellKeyInput).query(({ input }) => findCellWithOutput(input)),

  forceRefetch: publicProcedure.input(cellKeyInput).mutation(async ({ input }) => {
    const existingCell = await findCellWithOutput(input);

    // No cell yet: creating it is all that is needed here.
    if (existingCell === null) {
      await generateNewCell(input.variantId, input.scenarioId);
      return true;
    }

    const previousOutput = existingCell.modelOutput;
    if (previousOutput) {
      // TODO: Maybe keep these around to show previous generations?
      await prisma.modelOutput.delete({
        where: { id: previousOutput.id },
      });
    }

    // Reset the cell to PENDING before queueing the retrieval task.
    await prisma.scenarioVariantCell.update({
      where: { id: existingCell.id },
      data: { retrievalStatus: "PENDING" },
    });

    await queueLLMRetrievalTask(existingCell.id);
    return true;
  }),
});
|
||||||
@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
|
|||||||
import { autogenerateScenarioValues } from "../autogen";
|
import { autogenerateScenarioValues } from "../autogen";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { reevaluateAll } from "~/server/utils/evaluations";
|
import { reevaluateAll } from "~/server/utils/evaluations";
|
||||||
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|
||||||
export const scenariosRouter = createTRPCRouter({
|
export const scenariosRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||||
@@ -48,10 +49,21 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await prisma.$transaction([
|
const [scenario] = await prisma.$transaction([
|
||||||
createNewScenarioAction,
|
createNewScenarioAction,
|
||||||
recordExperimentUpdated(input.experimentId),
|
recordExperimentUpdated(input.experimentId),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
const promptVariants = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const variant of promptVariants) {
|
||||||
|
await generateNewCell(variant.id, scenario.id);
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
|
|
||||||
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||||
@@ -175,6 +187,17 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const promptVariants = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: newScenario.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const variant of promptVariants) {
|
||||||
|
await generateNewCell(variant.id, newScenario.id);
|
||||||
|
}
|
||||||
|
|
||||||
return newScenario;
|
return newScenario;
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
47
src/server/scripts/migrateScenarioVariantOutputData.ts
Normal file
47
src/server/scripts/migrateScenarioVariantOutputData.ts
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import { type Prisma } from "@prisma/client";
|
||||||
|
import { prisma } from "../db";
|
||||||
|
|
||||||
|
/**
 * One-off data migration: moves output data stored directly on
 * ScenarioVariantCell rows into dedicated ModelOutput rows.
 *
 * - A cell that has an `output` but no linked ModelOutput gets a new
 *   ModelOutput row copied from the cell's fields.
 * - Otherwise, a cell that carries an `errorMessage` while marked COMPLETE
 *   is re-flagged as ERROR.
 *
 * Cells are processed one at a time; progress is reported on stdout.
 */
async function migrateScenarioVariantOutputData() {
  const allCells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
  console.log(`Found ${allCells.length} records`);

  let updatedCount = 0;

  for (const cell of allCells) {
    const hasUnmigratedOutput = cell.output && !cell.modelOutput;

    if (hasUnmigratedOutput) {
      updatedCount++;
      const { id, inputHash, output, timeToComplete, promptTokens, completionTokens } = cell;
      await prisma.modelOutput.create({
        data: {
          scenarioVariantCellId: id,
          inputHash: inputHash || "",
          output: output as Prisma.InputJsonValue,
          timeToComplete: timeToComplete ?? undefined,
          promptTokens,
          completionTokens,
          // Keep the original timestamps on the migrated row.
          createdAt: cell.createdAt,
          updatedAt: cell.updatedAt,
        },
      });
      continue;
    }

    const isMislabelledError = cell.errorMessage && cell.retrievalStatus === "COMPLETE";
    if (isMislabelledError) {
      updatedCount++;
      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: {
          retrievalStatus: "ERROR",
        },
      });
    }
  }

  console.log("Data migration completed");
  console.log(`Updated ${updatedCount} records`);
}

// Run the migration; on failure, log the error and exit with status 1.
migrateScenarioVariantOutputData().catch(function onMigrationFailure(error) {
  console.error("An error occurred while migrating data: ", error);
  process.exit(1);
});
|
||||||
31
src/server/tasks/defineTask.ts
Normal file
31
src/server/tasks/defineTask.ts
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
// Import necessary dependencies
|
||||||
|
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
|
||||||
|
import { env } from "~/env.mjs";
|
||||||
|
|
||||||
|
/**
 * Bundles a graphile-worker task identifier with its handler.
 *
 * @param taskIdentifier Name the job is queued and registered under.
 * @param taskHandler Async function invoked with the job payload and the worker helpers.
 * @returns `enqueue`, which queues a job carrying the given payload via `quickAddJob`,
 *   and `task`, the `{ identifier, handler }` pair to register with a worker.
 */
function defineTask<TPayload>(
  taskIdentifier: string,
  taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
  async function enqueue(payload: TPayload) {
    console.log("Enqueuing task", taskIdentifier, payload);
    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
  }

  // Logs every run through the worker's logger before delegating to the real handler.
  function handler(payload: TPayload, helpers: Helpers) {
    helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
    return taskHandler(payload, helpers);
  }

  return {
    enqueue,
    task: {
      identifier: taskIdentifier,
      handler: handler as Task,
    },
  };
}
|
||||||
|
|
||||||
|
export default defineTask;
|
||||||
144
src/server/tasks/queryLLM.task.ts
Normal file
144
src/server/tasks/queryLLM.task.ts
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import defineTask from "./defineTask";
|
||||||
|
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
|
||||||
|
import { type JSONSerializable } from "../types";
|
||||||
|
import { sleep } from "../utils/sleep";
|
||||||
|
import { shouldStream } from "../utils/shouldStream";
|
||||||
|
import { generateChannel } from "~/utils/generateChannel";
|
||||||
|
import { reevaluateVariant } from "../utils/evaluations";
|
||||||
|
import { constructPrompt } from "../utils/constructPrompt";
|
||||||
|
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||||
|
|
||||||
|
// Retry tuning for rate-limited requests.
const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

/**
 * Computes how long to wait before the next attempt.
 *
 * The base delay doubles with every previous try, starting at MIN_DELAY and
 * capped at MAX_DELAY. A random jitter of up to one extra base delay is then
 * added, so the result lies in [base, 2 * base).
 *
 * @param numPreviousTries Number of attempts already made.
 * @returns Delay in milliseconds.
 */
function calculateDelay(numPreviousTries: number): number {
  const exponential = MIN_DELAY * Math.pow(2, numPreviousTries);
  const cappedBase = Math.min(MAX_DELAY, exponential);
  return cappedBase + Math.random() * cappedBase;
}
|
||||||
|
|
||||||
|
/**
 * Requests a completion, retrying while the response status is 429.
 *
 * Up to MAX_AUTO_RETRIES attempts are made. Before each wait the cell is
 * updated with a rate-limit error, status 429 and the time of the next
 * attempt. Any non-429 response, or the response of the final attempt
 * whatever its status, is returned as-is.
 *
 * @param cellId Id of the ScenarioVariantCell updated between attempts.
 * @param payload Prompt payload forwarded to `getCompletion`.
 * @param channel Optional channel forwarded to `getCompletion`.
 */
const getCompletionWithRetries = async (
  cellId: string,
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> => {
  const request = payload as unknown as CompletionCreateParams;

  let attempt = 0;
  while (attempt < MAX_AUTO_RETRIES) {
    const response = await getCompletion(request, channel);

    const rateLimited = response.statusCode === 429;
    const attemptsRemain = attempt < MAX_AUTO_RETRIES - 1;
    if (!(rateLimited && attemptsRemain)) {
      return response;
    }

    const waitMs = calculateDelay(attempt);
    const retryTime = new Date(Date.now() + waitMs);
    await prisma.scenarioVariantCell.update({
      where: { id: cellId },
      data: {
        errorMessage: "Rate limit exceeded",
        statusCode: 429,
        retryTime,
      },
    });

    // TODO: Maybe requeue the job so other jobs can run in the future?
    await sleep(waitMs);
    attempt += 1;
  }

  // The last attempt always returns above; this only satisfies the return type.
  throw new Error("Max retries limit reached");
};
|
||||||
|
|
||||||
|
/** Payload of a `queryLLM` job: the cell whose output should be retrieved. */
export type queryLLMJob = {
  scenarioVariantCellId: string;
};

/**
 * Worker task that retrieves the model output for one scenario/variant cell.
 *
 * Flow: load the cell, skip it unless it is PENDING, mark it IN_PROGRESS,
 * build the prompt, optionally publish a streaming channel on the cell,
 * request the completion (with rate-limit retries), store a ModelOutput row
 * on a 200 response, then record the final status on the cell and
 * re-evaluate the variant.
 *
 * Every exit after the cell has been marked IN_PROGRESS leaves it in a
 * terminal state (COMPLETE or ERROR); a cell left IN_PROGRESS would be
 * skipped by the PENDING guard on any later run.
 */
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
  const { scenarioVariantCellId } = task;

  const cell = await prisma.scenarioVariantCell.findUnique({
    where: { id: scenarioVariantCellId },
    include: { modelOutput: true },
  });
  if (!cell) {
    return;
  }

  // If cell is not pending, then some other job is already processing it
  if (cell.retrievalStatus !== "PENDING") {
    return;
  }

  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      retrievalStatus: "IN_PROGRESS",
    },
  });

  // Records a failure on the cell so it does not stay IN_PROGRESS.
  const markCellFailed = async (errorMessage: string) => {
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        errorMessage,
        streamingChannel: null,
        retrievalStatus: "ERROR",
      },
    });
  };

  const variant = await prisma.promptVariant.findUnique({
    where: { id: cell.promptVariantId },
  });
  if (!variant) {
    await markCellFailed("Prompt variant not found");
    return;
  }

  const scenario = await prisma.testScenario.findUnique({
    where: { id: cell.testScenarioId },
  });
  if (!scenario) {
    await markCellFailed("Test scenario not found");
    return;
  }

  let prompt: JSONSerializable;
  let modelResponse: CompletionResponse;
  try {
    prompt = await constructPrompt(variant, scenario.variableValues);

    let streamingChannel;
    if (shouldStream(prompt)) {
      streamingChannel = generateChannel();
      // Save streaming channel so that UI can connect to it
      await prisma.scenarioVariantCell.update({
        where: { id: scenarioVariantCellId },
        data: {
          streamingChannel,
        },
      });
    }

    modelResponse = await getCompletionWithRetries(
      scenarioVariantCellId,
      prompt,
      streamingChannel,
    );
  } catch (error: unknown) {
    // Record the failure on the cell, then rethrow so the worker still sees the job fail.
    await markCellFailed(error instanceof Error ? error.message : String(error));
    throw error;
  }

  const succeeded = modelResponse.statusCode === 200;
  if (succeeded) {
    const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

    // Creating the row with `scenarioVariantCellId` is what links it to the cell.
    await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId,
        inputHash,
        output: modelResponse.output,
        timeToComplete: modelResponse.timeToComplete,
        promptTokens: modelResponse.promptTokens,
        completionTokens: modelResponse.completionTokens,
      },
    });
  }

  // No `modelOutput.connect` here: the link is made by the create above, and on
  // the error path there is no output id to connect, which would make this
  // update fail and lose the error status.
  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      statusCode: modelResponse.statusCode,
      errorMessage: modelResponse.errorMessage,
      streamingChannel: null,
      retrievalStatus: succeeded ? "COMPLETE" : "ERROR",
    },
  });

  await reevaluateVariant(cell.promptVariantId);
});
|
||||||
40
src/server/tasks/worker.ts
Normal file
40
src/server/tasks/worker.ts
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import { type TaskList, run } from "graphile-worker";
|
||||||
|
import "dotenv/config";
|
||||||
|
|
||||||
|
import { env } from "~/env.mjs";
|
||||||
|
import { queryLLM } from "./queryLLM.task";
|
||||||
|
|
||||||
|
// Every task this worker process is able to run.
const registeredTasks = [queryLLM];

// Map of task identifier -> handler, in the shape graphile-worker expects.
const taskList: TaskList = {};
for (const { task } of registeredTasks) {
  taskList[task.identifier] = task.handler;
}

/**
 * Starts a graphile-worker runner against DATABASE_URL and waits on it
 * until the worker exits.
 */
async function main() {
  const runner = await run({
    connectionString: env.DATABASE_URL,
    concurrency: 20,
    // Keep graphile-worker's own signal handlers (graceful shutdown on SIGINT, SIGTERM, etc).
    noHandleSignals: false,
    pollInterval: 1000,
    // Only one of taskList / taskDirectory may be set; tasks are registered in code here.
    taskList,
  });

  // Await the runner's promise right away so a failure surfaces here instead
  // of as an unhandled rejection. It settles when the worker exits, whether
  // through a fatal error or otherwise.
  await runner.promise;
}

main().catch(function onWorkerFailure(err) {
  console.error("Unhandled error occurred running worker: ", err);
  process.exit(1);
});
|
||||||
@@ -9,7 +9,7 @@ export async function constructPrompt(
|
|||||||
scenario: TestScenario["variableValues"],
|
scenario: TestScenario["variableValues"],
|
||||||
): Promise<JSONSerializable> {
|
): Promise<JSONSerializable> {
|
||||||
const code = `
|
const code = `
|
||||||
const scenario = ${JSON.stringify(scenario, null, 2)};
|
const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
|
||||||
let prompt
|
let prompt
|
||||||
|
|
||||||
${variant.constructFn}
|
${variant.constructFn}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { type Evaluation } from "@prisma/client";
|
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { evaluateOutput } from "./evaluateOutput";
|
import { evaluateOutput } from "./evaluateOutput";
|
||||||
|
|
||||||
@@ -12,21 +12,22 @@ export const reevaluateVariant = async (variantId: string) => {
|
|||||||
where: { experimentId: variant.experimentId },
|
where: { experimentId: variant.experimentId },
|
||||||
});
|
});
|
||||||
|
|
||||||
const modelOutputs = await prisma.modelOutput.findMany({
|
const cells = await prisma.scenarioVariantCell.findMany({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: variantId,
|
promptVariantId: variantId,
|
||||||
statusCode: { notIn: [429] },
|
retrievalStatus: "COMPLETE",
|
||||||
testScenario: { visible: true },
|
testScenario: { visible: true },
|
||||||
|
modelOutput: { isNot: null },
|
||||||
},
|
},
|
||||||
include: { testScenario: true },
|
include: { testScenario: true, modelOutput: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
evaluations.map(async (evaluation) => {
|
evaluations.map(async (evaluation) => {
|
||||||
const passCount = modelOutputs.filter((output) =>
|
const passCount = cells.filter((cell) =>
|
||||||
evaluateOutput(output, output.testScenario, evaluation),
|
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||||
).length;
|
).length;
|
||||||
const failCount = modelOutputs.length - passCount;
|
const failCount = cells.length - passCount;
|
||||||
|
|
||||||
await prisma.evaluationResult.upsert({
|
await prisma.evaluationResult.upsert({
|
||||||
where: {
|
where: {
|
||||||
@@ -55,22 +56,23 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
|||||||
where: { experimentId: evaluation.experimentId, visible: true },
|
where: { experimentId: evaluation.experimentId, visible: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
const modelOutputs = await prisma.modelOutput.findMany({
|
const cells = await prisma.scenarioVariantCell.findMany({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: { in: variants.map((v) => v.id) },
|
promptVariantId: { in: variants.map((v) => v.id) },
|
||||||
testScenario: { visible: true },
|
testScenario: { visible: true },
|
||||||
statusCode: { notIn: [429] },
|
statusCode: { notIn: [429] },
|
||||||
|
modelOutput: { isNot: null },
|
||||||
},
|
},
|
||||||
include: { testScenario: true },
|
include: { testScenario: true, modelOutput: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
variants.map(async (variant) => {
|
variants.map(async (variant) => {
|
||||||
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
|
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||||
const passCount = outputs.filter((output) =>
|
const passCount = variantCells.filter((cell) =>
|
||||||
evaluateOutput(output, output.testScenario, evaluation),
|
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||||
).length;
|
).length;
|
||||||
const failCount = outputs.length - passCount;
|
const failCount = variantCells.length - passCount;
|
||||||
|
|
||||||
await prisma.evaluationResult.upsert({
|
await prisma.evaluationResult.upsert({
|
||||||
where: {
|
where: {
|
||||||
|
|||||||
76
src/server/utils/generateNewCell.ts
Normal file
76
src/server/utils/generateNewCell.ts
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import { type Prisma } from "@prisma/client";
|
||||||
|
import { prisma } from "../db";
|
||||||
|
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||||
|
import { constructPrompt } from "./constructPrompt";
|
||||||
|
|
||||||
|
/**
 * Returns the cell for a (prompt variant, test scenario) pair, creating it when
 * it does not exist yet.
 *
 * For a newly created cell, the constructed prompt is hashed. If a model output
 * with the same `inputHash` already exists, its contents are copied into a new
 * output attached to this cell; otherwise an LLM retrieval is queued.
 *
 * @param variantId - id of the prompt variant
 * @param scenarioId - id of the test scenario
 * @returns `null` when the variant or the scenario does not exist; otherwise
 *   the cell together with its `modelOutput` (`null` while no output exists).
 */
export const generateNewCell = async (variantId: string, scenarioId: string) => {
  // The two lookups do not depend on each other, so run them concurrently.
  const [variant, scenario] = await Promise.all([
    prisma.promptVariant.findUnique({ where: { id: variantId } }),
    prisma.testScenario.findUnique({ where: { id: scenarioId } }),
  ]);

  if (!variant || !scenario) return null;

  const existingCell = await prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: variantId,
        testScenarioId: scenarioId,
      },
    },
    include: { modelOutput: true },
  });

  if (existingCell) return existingCell;

  // Build and hash the prompt only now that a new cell is actually needed.
  // This still happens before the cell is created, so a failure while
  // constructing the prompt does not leave an empty cell behind.
  const prompt = await constructPrompt(variant, scenario.variableValues);
  const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

  // NOTE(review): the find-then-create sequence is not atomic. Two concurrent
  // calls for the same pair may both reach this point, and the second create
  // would presumably violate the unique constraint. Confirm whether callers
  // can race.
  const createdCell = await prisma.scenarioVariantCell.create({
    data: {
      promptVariantId: variantId,
      testScenarioId: scenarioId,
    },
    include: { modelOutput: true },
  });

  const matchingModelOutput = await prisma.modelOutput.findFirst({
    where: { inputHash },
  });

  if (!matchingModelOutput) {
    // Nothing to reuse. Return the cell exactly as the queueing helper loaded
    // it (including its `modelOutput`) instead of overwriting that field with
    // `undefined`.
    return queueLLMRetrievalTask(createdCell.id);
  }

  // Reuse the stored result for an identical prompt rather than querying the
  // model again. The timestamps of the source output are carried over.
  // NOTE(review): the cell's `retrievalStatus` is not updated on this path.
  // Confirm that its default is what readers of the cell expect.
  const copiedModelOutput = await prisma.modelOutput.create({
    data: {
      scenarioVariantCellId: createdCell.id,
      inputHash,
      output: matchingModelOutput.output as Prisma.InputJsonValue,
      timeToComplete: matchingModelOutput.timeToComplete,
      promptTokens: matchingModelOutput.promptTokens,
      completionTokens: matchingModelOutput.completionTokens,
      createdAt: matchingModelOutput.createdAt,
      updatedAt: matchingModelOutput.updatedAt,
    },
  });

  return { ...createdCell, modelOutput: copiedModelOutput };
};
|
||||||
@@ -9,7 +9,7 @@ import { env } from "~/env.mjs";
|
|||||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||||
|
|
||||||
type CompletionResponse = {
|
export type CompletionResponse = {
|
||||||
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
||||||
statusCode: number;
|
statusCode: number;
|
||||||
errorMessage: string | null;
|
errorMessage: string | null;
|
||||||
|
|||||||
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import { prisma } from "../db";
|
||||||
|
import { queryLLM } from "../tasks/queryLLM.task";
|
||||||
|
|
||||||
|
/**
 * Marks a cell as awaiting retrieval and starts the LLM query for it.
 *
 * The cell's `retrievalStatus` is set to "PENDING" and its `errorMessage` is
 * cleared. The query handler is then started without being awaited, so the
 * returned cell reflects the state before the query has finished.
 *
 * @param cellId - id of the scenario/variant cell to retrieve output for
 * @returns the updated cell, including its `modelOutput` relation
 */
export const queueLLMRetrievalTask = async (cellId: string) => {
  const updatedCell = await prisma.scenarioVariantCell.update({
    where: { id: cellId },
    data: {
      retrievalStatus: "PENDING",
      errorMessage: null,
    },
    include: { modelOutput: true },
  });

  // The handler is invoked directly rather than through the job queue. The
  // call must stay on a single line so the directive below keeps applying.
  // @ts-expect-error we aren't passing the helpers but that's ok
  const retrieval: unknown = queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });

  // Fire-and-forget, but never leave a rejection unhandled: an unhandled
  // rejection would otherwise take down the whole server process.
  void Promise.resolve(retrieval).catch((error: unknown) => {
    console.error(`LLM retrieval failed for cell ${cellId}: `, error);
  });

  return updatedCell;
};
|
||||||
7
src/server/utils/shouldStream.ts
Normal file
7
src/server/utils/shouldStream.ts
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import { isObject } from "lodash";
|
||||||
|
import { type JSONSerializable } from "../types";
|
||||||
|
|
||||||
|
/**
 * Tells whether a config asks for a streamed response.
 *
 * @param config - the value to inspect
 * @returns `true` only when `config` is an object whose `stream` property is
 *   strictly `true`; `false` for everything else.
 */
export const shouldStream = (config: JSONSerializable): boolean =>
  isObject(config) && "stream" in config && config.stream === true;
|
||||||
1
src/server/utils/sleep.ts
Normal file
1
src/server/utils/sleep.ts
Normal file
@@ -0,0 +1 @@
|
|||||||
|
/**
 * Returns a promise that resolves (with no value) once `ms` milliseconds have
 * elapsed on a `setTimeout` timer.
 */
export const sleep = (ms: number) =>
  new Promise((done) => {
    setTimeout(done, ms);
  });
|
||||||
@@ -33,6 +33,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
|||||||
|
|
||||||
monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
|
monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
|
||||||
allowNonTsExtensions: true,
|
allowNonTsExtensions: true,
|
||||||
|
strictNullChecks: true,
|
||||||
lib: ["esnext"],
|
lib: ["esnext"],
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
|
|||||||
)} as const;
|
)} as const;
|
||||||
|
|
||||||
type Scenario = typeof scenarios[number];
|
type Scenario = typeof scenarios[number];
|
||||||
declare var scenario: Scenario | null;
|
declare var scenario: Scenario | { [key: string]: string };
|
||||||
`;
|
`;
|
||||||
|
|
||||||
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
|
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
|
|||||||
|
|
||||||
const url = env.NEXT_PUBLIC_SOCKET_URL;
|
const url = env.NEXT_PUBLIC_SOCKET_URL;
|
||||||
|
|
||||||
export default function useSocket(channel?: string) {
|
export default function useSocket(channel?: string | null) {
|
||||||
const socketRef = useRef<Socket>();
|
const socketRef = useRef<Socket>();
|
||||||
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user