Compare commits
21 Commits
prompt-tem...save-butto
| Author | SHA1 | Date |
|---|---|---|
| | 74c201d3a8 | |
| | ab9c721d09 | |
| | 0a2578a1d8 | |
| | 1bebaff386 | |
| | 3bf5eaf4a2 | |
| | ded97f8bb9 | |
| | 26ee8698be | |
| | b98eb9b729 | |
| | 032c07ec65 | |
| | 80c0d13bb9 | |
| | f7c94be3f6 | |
| | c3e85607e0 | |
| | cd5927b8f5 | |
| | 731406d1f4 | |
| | 3c59e4b774 | |
| | a20f81939d | |
| | 972b1f2333 | |
| | 7321f3deda | |
| | 2bd41fdfbf | |
| | a5378b106b | |
| | 0371dacfca | |
.github/workflows/ci.yaml (51 lines, vendored, new file)

@@ -0,0 +1,51 @@
name: CI checks

on:
  pull_request:
    branches: [main]

jobs:
  run-checks:
    runs-on: ubuntu-latest

    steps:
      - name: Check out code
        uses: actions/checkout@v2

      - name: Set up Node.js
        uses: actions/setup-node@v2
        with:
          node-version: "20"

      - uses: pnpm/action-setup@v2
        name: Install pnpm
        id: pnpm-install
        with:
          version: 8.6.1
          run_install: false

      - name: Get pnpm store directory
        id: pnpm-cache
        shell: bash
        run: |
          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT

      - uses: actions/cache@v3
        name: Setup pnpm cache
        with:
          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-store-

      - name: Install Dependencies
        run: pnpm install

      - name: Check types
        run: pnpm tsc

      - name: Lint
        run: SKIP_ENV_VALIDATION=1 pnpm lint

      - name: Check prettier
        run: pnpm prettier . --check

.tool-versions (1 line, new file)

@@ -0,0 +1 @@
nodejs 20.2.0

package.json (13 lines changed)

@@ -3,10 +3,15 @@
  "type": "module",
  "version": "0.1.0",
  "license": "Apache-2.0",
  "engines": {
    "node": ">=20.0.0",
    "pnpm": ">=8.6.1"
  },
  "scripts": {
    "build": "next build",
    "dev:next": "next dev",
    "dev:wss": "pnpm tsx --watch src/wss-server.ts",
    "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
    "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
    "postinstall": "prisma generate",
    "lint": "next lint",
@@ -14,6 +19,8 @@
    "codegen": "tsx src/codegen/export-openai-types.ts"
  },
  "dependencies": {
    "@babel/preset-typescript": "^7.22.5",
    "@babel/standalone": "^7.22.9",
    "@chakra-ui/next-js": "^2.1.4",
    "@chakra-ui/react": "^2.7.1",
    "@emotion/react": "^11.11.1",
@@ -38,6 +45,7 @@
    "express": "^4.18.2",
    "framer-motion": "^10.12.17",
    "gpt-tokens": "^1.0.10",
    "graphile-worker": "^0.13.0",
    "immer": "^10.0.2",
    "isolated-vm": "^4.5.0",
    "json-stringify-pretty-compact": "^4.0.0",
@@ -48,6 +56,7 @@
    "openai": "4.0.0-beta.2",
    "pluralize": "^8.0.0",
    "posthog-js": "^1.68.4",
    "prettier": "^3.0.0",
    "react": "18.2.0",
    "react-dom": "18.2.0",
    "react-icons": "^4.10.1",
@@ -62,6 +71,8 @@
  },
  "devDependencies": {
    "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
    "@types/babel__core": "^7.20.1",
    "@types/babel__standalone": "^7.1.4",
    "@types/chroma-js": "^2.4.0",
    "@types/cors": "^2.8.13",
    "@types/eslint": "^8.37.0",
@@ -77,8 +88,8 @@
    "eslint": "^8.40.0",
    "eslint-config-next": "^13.4.2",
    "eslint-plugin-unused-imports": "^2.0.0",
    "monaco-editor": "^0.40.0",
    "openapi-typescript": "^6.3.4",
    "prettier": "^3.0.0",
    "prisma": "^4.14.0",
    "raw-loader": "^4.0.2",
    "typescript": "^5.0.4",

pnpm-lock.yaml (916 lines changed, generated)

File diff suppressed because it is too large.

@@ -0,0 +1,49 @@
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";

-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";

-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
    ALTER COLUMN "statusCode" DROP NOT NULL,
    ALTER COLUMN "timeToComplete" DROP NOT NULL;

-- Create the new table
CREATE TABLE "ModelOutput" (
    "id" UUID NOT NULL,
    "inputHash" TEXT NOT NULL,
    "output" JSONB NOT NULL,
    "timeToComplete" INTEGER NOT NULL DEFAULT 0,
    "promptTokens" INTEGER,
    "completionTokens" INTEGER,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "scenarioVariantCellId" UUID
);

-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");

CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;

ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
    ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");

ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');

-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

@@ -26,10 +26,11 @@ model Experiment {
}

model PromptVariant {
  id String @id @default(uuid()) @db.Uuid
  label String
  id String @id @default(uuid()) @db.Uuid

  label String
  constructFn String
  model String @default("gpt-3.5-turbo")

  uiId String @default(uuid()) @db.Uuid
  visible Boolean @default(true)
@@ -40,7 +41,7 @@ model PromptVariant {

  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt
  ModelOutput ModelOutput[]
  scenarioVariantCells ScenarioVariantCell[]
  EvaluationResult EvaluationResult[]

  @@index([uiId])
@@ -60,7 +61,7 @@ model TestScenario {

  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt
  ModelOutput ModelOutput[]
  scenarioVariantCells ScenarioVariantCell[]
}

model TemplateVariable {
@@ -75,17 +76,28 @@ model TemplateVariable {
  updatedAt DateTime @updatedAt
}

model ModelOutput {
enum CellRetrievalStatus {
  PENDING
  IN_PROGRESS
  COMPLETE
  ERROR
}

model ScenarioVariantCell {
  id String @id @default(uuid()) @db.Uuid

  inputHash String
  output Json
  statusCode Int
  errorMessage String?
  timeToComplete Int @default(0)
  inputHash String? // TODO: Remove once migration is complete
  output Json? // TODO: Remove once migration is complete
  statusCode Int?
  errorMessage String?
  timeToComplete Int? @default(0) // TODO: Remove once migration is complete
  retryTime DateTime?
  streamingChannel String?
  retrievalStatus CellRetrievalStatus @default(COMPLETE)

  promptTokens Int? // Added promptTokens field
  completionTokens Int? // Added completionTokens field
  promptTokens Int? // TODO: Remove once migration is complete
  completionTokens Int? // TODO: Remove once migration is complete
  modelOutput ModelOutput?

  promptVariantId String @db.Uuid
  promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -97,6 +109,24 @@ model ModelOutput {
  updatedAt DateTime @updatedAt

  @@unique([promptVariantId, testScenarioId])
}

model ModelOutput {
  id String @id @default(uuid()) @db.Uuid

  inputHash String
  output Json
  timeToComplete Int @default(0)
  promptTokens Int?
  completionTokens Int?

  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  scenarioVariantCellId String @db.Uuid
  scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)

  @@unique([scenarioVariantCellId])
  @@index([inputHash])
}

@@ -16,7 +16,7 @@ const experiment = await prisma.experiment.create({
  },
});

await prisma.modelOutput.deleteMany({
await prisma.scenarioVariantCell.deleteMany({
  where: {
    promptVariant: {
      experimentId,

src/components/OutputsTable/OutputCell/CellOptions.tsx (33 lines, new file)

@@ -0,0 +1,33 @@
import { Button, HStack, Icon } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";

export const CellOptions = ({
  refetchingOutput,
  refetchOutput,
}: {
  refetchingOutput: boolean;
  refetchOutput: () => void;
}) => {
  return (
    <HStack justifyContent="flex-end" w="full">
      {!refetchingOutput && (
        <Button
          size="xs"
          w={4}
          h={4}
          py={4}
          px={4}
          minW={0}
          borderRadius={8}
          color="gray.500"
          variant="ghost"
          cursor="pointer"
          onClick={refetchOutput}
          aria-label="refetch output"
        >
          <Icon as={BsArrowClockwise} boxSize={4} />
        </Button>
      )}
    </HStack>
  );
};

@@ -1,29 +1,21 @@
import { type ModelOutput } from "@prisma/client";
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
import { type ScenarioVariantCell } from "@prisma/client";
import { VStack, Text } from "@chakra-ui/react";
import { useEffect, useState } from "react";
import { BsArrowClockwise } from "react-icons/bs";
import { rateLimitErrorMessage } from "~/sharedStrings";
import pluralize from "pluralize";

const MAX_AUTO_RETRIES = 3;

export const ErrorHandler = ({
  output,
  cell,
  refetchOutput,
  numPreviousTries,
}: {
  output: ModelOutput;
  cell: ScenarioVariantCell;
  refetchOutput: () => void;
  numPreviousTries: number;
}) => {
  const [msToWait, setMsToWait] = useState(0);
  const shouldAutoRetry =
    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;

  useEffect(() => {
    if (!shouldAutoRetry) return;
    if (!cell.retryTime) return;

    const initialWaitTime = calculateDelay(numPreviousTries);
    const initialWaitTime = cell.retryTime.getTime() - Date.now();
    const msModuloOneSecond = initialWaitTime % 1000;
    let remainingTime = initialWaitTime - msModuloOneSecond;
    setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
      setMsToWait(remainingTime);

      if (remainingTime <= 0) {
        refetchOutput();
        clearInterval(interval);
      }
    }, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
      clearInterval(interval);
      clearTimeout(timeout);
    };
  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
  }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);

  return (
    <VStack w="full">
      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
        <Text color="red.600" fontWeight="bold">
          Error
        </Text>
        <Button
          size="xs"
          w={4}
          h={4}
          px={4}
          py={4}
          minW={0}
          borderRadius={8}
          variant="ghost"
          cursor="pointer"
          onClick={refetchOutput}
          aria-label="refetch output"
        >
          <Icon as={BsArrowClockwise} boxSize={6} />
        </Button>
      </HStack>
      <Text color="red.600" wordBreak="break-word">
        {output.errorMessage}
        {cell.errorMessage}
      </Text>
      {msToWait > 0 && (
        <Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
    </VStack>
  );
};

const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds

function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}

@@ -1,17 +1,16 @@
import { type RouterOutputs, api } from "~/utils/api";
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
import { type ReactElement, useState, useEffect } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";

export default function OutputCell({
  scenario,
@@ -37,120 +36,103 @@ export default function OutputCell({
  // if (variant.config === null || Object.keys(variant.config).length === 0)
  //   disabledReason = "Save your prompt variant to see output";

  // const model = getModelName(variant.config as JSONSerializable);
  // TODO: Temporarily hardcoding this while we get other stuff working
  const model = "gpt-3.5-turbo";

  const outputMutation = api.outputs.get.useMutation();

  const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
  const [channel, setChannel] = useState<string | undefined>(undefined);
  const [numPreviousTries, setNumPreviousTries] = useState(0);

  const fetchMutex = useRef(false);
  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
    async (forceRefetch?: boolean) => {
      if (fetchMutex.current) return;
      setNumPreviousTries((prev) => prev + 1);

      fetchMutex.current = true;
      setOutput(null);

      const shouldStream =
        isObject(variant) &&
        "config" in variant &&
        isObject(variant.config) &&
        "stream" in variant.config &&
        variant.config.stream === true;

      const channel = shouldStream ? generateChannel() : undefined;
      setChannel(channel);

      const output = await outputMutation.mutateAsync({
        scenarioId: scenario.id,
        variantId: variant.id,
        channel,
        forceRefetch,
      });
      setOutput(output);
      await utils.promptVariants.stats.invalidate();
      fetchMutex.current = false;
    },
    [outputMutation, scenario.id, variant.id],
  const [refetchInterval, setRefetchInterval] = useState(0);
  const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
    { scenarioId: scenario.id, variantId: variant.id },
    { refetchInterval },
  );
  const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);

  useEffect(fetchOutput, [scenario.id, variant.id]);
  const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
    api.scenarioVariantCells.forceRefetch.useMutation();
  const [hardRefetch] = useHandledAsyncCallback(async () => {
    await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
    await utils.scenarioVariantCells.get.invalidate({
      scenarioId: scenario.id,
      variantId: variant.id,
    });
  }, [hardRefetchMutate, scenario.id, variant.id]);

  const fetchingOutput = queryLoading || refetchingOutput;

  const awaitingOutput =
    !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
  useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);

  const modelOutput = cell?.modelOutput;

  // Disconnect from socket if we're not streaming anymore
  const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
  const streamedMessage = useSocket(cell?.streamingChannel);
  const streamedContent = streamedMessage?.choices?.[0]?.message?.content;

  if (!vars) return null;

  if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;

  if (fetchingOutput && !streamedMessage)
  if (awaitingOutput && !streamedMessage)
    return (
      <Center h="100%" w="100%">
        <Spinner />
      </Center>
    );

  if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
  if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;

  if (output && output.errorMessage) {
    return (
      <ErrorHandler
        output={output}
        refetchOutput={hardRefetch}
        numPreviousTries={numPreviousTries}
      />
    );
  if (cell && cell.errorMessage) {
    return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
  }

  const response = output?.output as unknown as ChatCompletion;
  const response = modelOutput?.output as unknown as ChatCompletion;
  const message = response?.choices?.[0]?.message;

  if (output && message?.function_call) {
  if (modelOutput && message?.function_call) {
    const rawArgs = message.function_call.arguments ?? "null";
    let parsedArgs: string;
    try {
      parsedArgs = JSON.parse(rawArgs);
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } catch (e: any) {
      parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
    }

    return (
      <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
        <SyntaxHighlighter
          customStyle={{ overflowX: "unset" }}
          language="json"
          style={docco}
          lineProps={{
            style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
          }}
          wrapLines
        >
          {stringify(
            {
              function: message.function_call.name,
              args: parsedArgs,
            },
            { maxLength: 40 },
          )}
        </SyntaxHighlighter>
        <OutputStats model={model} modelOutput={output} scenario={scenario} />
        <VStack w="full" spacing={0}>
          <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
          <SyntaxHighlighter
            customStyle={{ overflowX: "unset" }}
            language="json"
            style={docco}
            lineProps={{
              style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
            }}
            wrapLines
          >
            {stringify(
              {
                function: message.function_call.name,
                args: parsedArgs,
              },
              { maxLength: 40 },
            )}
          </SyntaxHighlighter>
        </VStack>
        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
      </Box>
    );
  }

  const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
  const contentToDisplay =
    message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);

  return (
    <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
      {contentToDisplay}
      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
      <VStack w="full" alignItems="flex-start" spacing={0}>
        <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
        <Text>{contentToDisplay}</Text>
      </VStack>
      {modelOutput && (
        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
      )}
    </Flex>
  );
}

@@ -9,15 +9,15 @@ import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip";

const SHOW_COST = false;
const SHOW_TIME = false;
const SHOW_COST = true;
const SHOW_TIME = true;

export const OutputStats = ({
  model,
  modelOutput,
  scenario,
}: {
  model: SupportedModel | null;
  model: SupportedModel | string | null;
  modelOutput: ModelOutput;
  scenario: Scenario;
}) => {
@@ -35,8 +35,6 @@ export const OutputStats = ({

  const cost = promptCost + completionCost;

  if (!evals.length) return null;

  return (
    <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
      <HStack flex={1}>

@@ -13,10 +13,11 @@ import AutoResizeTextArea from "../AutoResizeTextArea";

export default function ScenarioEditor({
  scenario,
  hovered,
  ...props
}: {
  scenario: Scenario;
  hovered: boolean;
  canHide: boolean;
}) {
  const savedValues = scenario.variableValues as Record<string, string>;
  const utils = api.useContext();
@@ -92,30 +93,34 @@ export default function ScenarioEditor({
      onDrop={onReorder}
      backgroundColor={isDragTarget ? "gray.100" : "transparent"}
    >
      <Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
        <Tooltip label="Hide scenario" hasArrow>
          {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
          <Button
            variant="unstyled"
            color="gray.400"
            height="unset"
            width="unset"
            minW="unset"
            onClick={onHide}
            _hover={{
              color: "gray.800",
              cursor: "pointer",
            }}
          >
            <Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
          </Button>
        </Tooltip>
        <Icon
          as={RiDraggable}
          boxSize={6}
          color="gray.400"
          _hover={{ color: "gray.800", cursor: "pointer" }}
        />
      <Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
        {props.canHide && (
          <>
            <Tooltip label="Hide scenario" hasArrow>
              {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
              <Button
                variant="unstyled"
                color="gray.400"
                height="unset"
                width="unset"
                minW="unset"
                onClick={onHide}
                _hover={{
                  color: "gray.800",
                  cursor: "pointer",
                }}
              >
                <Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
              </Button>
            </Tooltip>
            <Icon
              as={RiDraggable}
              boxSize={6}
              color="gray.400"
              _hover={{ color: "gray.800", cursor: "pointer" }}
            />
          </>
        )}
      </Stack>
      {variableLabels.length === 0 ? (
        <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>

@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types";

const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
const ScenarioRow = (props: {
  scenario: Scenario;
  variants: PromptVariant[];
  canHide: boolean;
}) => {
  const [isHovered, setIsHovered] = useState(false);

  const highlightStyle = { backgroundColor: "gray.50" };
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
        sx={isHovered ? highlightStyle : undefined}
        borderLeftWidth={1}
      >
        <ScenarioEditor scenario={props.scenario} hovered={isHovered} />
        <ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
      </GridItem>
      {props.variants.map((variant) => (
        <GridItem

@@ -1,12 +1,12 @@
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
// import openAITypes from "~/codegen/openai.types.ts.txt";
import { editorBackground } from "~/state/sharedVariantEditor.slice";

export default function VariantConfigEditor(props: { variant: PromptVariant }) {
export default function VariantEditor(props: { variant: PromptVariant }) {
  const monaco = useAppStore.use.sharedVariantEditor.monaco();
  const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
  const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -18,20 +18,27 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {

  const checkForChanges = useCallback(() => {
    if (!editorRef.current) return;
    const currentConfig = editorRef.current.getValue();
    setIsChanged(currentConfig !== lastSavedFn);
    const currentFn = editorRef.current.getValue();
    setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
  }, [lastSavedFn]);

  useEffect(checkForChanges, [checkForChanges, lastSavedFn]);

  const replaceVariant = api.promptVariants.replaceVariant.useMutation();
  const utils = api.useContext();
  const toast = useToast();

  const [onSave] = useHandledAsyncCallback(async () => {
    const currentFn = editorRef.current?.getValue();
    if (!editorRef.current) return;

    await editorRef.current.getAction("editor.action.formatDocument")?.run();

    const currentFn = editorRef.current.getValue();

    if (!currentFn) return;

    // Check if the editor has any typescript errors
    const model = editorRef.current?.getModel();
    const model = editorRef.current.getModel();
    if (!model) return;

    const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
@@ -59,14 +66,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
      return;
    }

    await replaceVariant.mutateAsync({
    const resp = await replaceVariant.mutateAsync({
      id: props.variant.id,
      constructFn: currentFn,
    });
    if (resp.status === "error") {
      return toast({
        title: "Error saving variant",
        description: resp.message,
        status: "error",
      });
    }

    setIsChanged(false);

    await utils.promptVariants.list.invalidate();

    checkForChanges();
  }, [checkForChanges]);

  useEffect(() => {
@@ -117,21 +131,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
    /* eslint-disable-next-line react-hooks/exhaustive-deps */
  }, [monaco, editorId]);

  // useEffect(() => {
  //   const savedConfigChanged = lastSavedFn !== savedConfig;

  //   lastSavedFn = savedConfig;

  //   if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
  //     editorRef.current?.setValue(savedConfig);
  //   }

  //   checkForChanges();
  // }, [savedConfig, checkForChanges]);

  return (
    <Box w="100%" pos="relative">
      <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
      <VStack
        spacing={0}
        align="stretch"
        fontSize="xs"
        fontWeight="bold"
        color="gray.600"
        py={2}
        bgColor={editorBackground}
      >
        <code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
        <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
        <code>{`return prompt; }`}</code>
      </VStack>
      {isChanged && (
        <HStack pos="absolute" bottom={2} right={2}>
          <Button

@@ -8,7 +8,7 @@ import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";

export default function VariantHeader(props: { variant: PromptVariant }) {
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
  const utils = api.useContext();
  const [isDragTarget, setIsDragTarget] = useState(false);
  const [isInputHovered, setIsInputHovered] = useState(false);
@@ -95,11 +95,13 @@ export default function VariantHeader(props: { variant: PromptVariant }) {
        onMouseEnter={() => setIsInputHovered(true)}
        onMouseLeave={() => setIsInputHovered(false)}
      />
      <Tooltip label="Hide Variant" hasArrow>
        <Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
          <Icon as={BsX} boxSize={6} />
        </Button>
      </Tooltip>
      {props.canHide && (
        <Tooltip label="Remove Variant" hasArrow>
          <Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
            <Icon as={BsX} boxSize={6} />
          </Button>
        </Tooltip>
      )}
    </HStack>
  );
}

@@ -3,7 +3,7 @@ import { api } from "~/utils/api";
import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantConfigEditor from "./VariantEditor";
import VariantEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader";
import { cellPadding } from "../constants";
import { BsPencil } from "react-icons/bs";
@@ -78,7 +78,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |

      {variants.data.map((variant) => (
        <GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
          <VariantHeader variant={variant} />
          <VariantHeader variant={variant} canHide={variants.data.length > 1} />
        </GridItem>
      ))}
      <GridItem
@@ -94,7 +94,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |

      {variants.data.map((variant) => (
        <GridItem key={variant.uiId}>
          <VariantConfigEditor variant={variant} />
          <VariantEditor variant={variant} />
        </GridItem>
      ))}
      {variants.data.map((variant) => (
@@ -103,7 +103,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
        </GridItem>
      ))}
      {scenarios.data.map((scenario) => (
        <ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
        <ScenarioRow
          key={scenario.uiId}
          scenario={scenario}
          variants={variants.data}
          canHide={scenarios.data.length > 1}
        />
      ))}
      <GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
        <NewScenarioButton />

@@ -1,7 +1,7 @@
import { type GetServerSideProps } from "next";

// eslint-disable-next-line @typescript-eslint/require-await
export const getServerSideProps: GetServerSideProps = async (context) => {
export const getServerSideProps: GetServerSideProps = async () => {
  return {
    redirect: {
      destination: "/experiments",

@@ -3,10 +3,6 @@ import { prisma } from "../db";
import { openai } from "../utils/openai";
import { pick } from "lodash";

function promptHasVariable(prompt: string, variableName: string) {
  return prompt.includes(`{{${variableName}}}`);
}

type AxiosError = {
  response?: {
    data?: {

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router";
import { modelOutputsRouter } from "./routers/modelOutputs.router";
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
import { templateVarsRouter } from "./routers/templateVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router";

@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
  promptVariants: promptVariantsRouter,
  experiments: experimentsRouter,
  scenarios: scenariosRouter,
  outputs: modelOutputsRouter,
  scenarioVariantCells: scenarioVariantCellsRouter,
  templateVars: templateVarsRouter,
  evaluations: evaluationsRouter,
});

@@ -2,6 +2,7 @@ import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";

export const experimentsRouter = createTRPCRouter({
  list: publicProcedure.query(async () => {
@@ -64,7 +65,7 @@ export const experimentsRouter = createTRPCRouter({
      },
    });

    await prisma.$transaction([
    const [variant, _, scenario] = await prisma.$transaction([
      prisma.promptVariant.create({
        data: {
          experimentId: exp.id,
@@ -73,18 +74,29 @@ export const experimentsRouter = createTRPCRouter({
          constructFn: dedent`prompt = {
            model: "gpt-3.5-turbo-0613",
            stream: true,
            messages: [{ role: "system", content: "Return 'Ready to go!'" }],
            messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }],
          }`,
          model: "gpt-3.5-turbo-0613",
        },
      }),
      prisma.templateVariable.create({
        data: {
          experimentId: exp.id,
          label: "text",
        },
      }),
      prisma.testScenario.create({
        data: {
          experimentId: exp.id,
          variableValues: {},
          variableValues: {
            text: "This is a test scenario.",
          },
        },
      }),
    ]);

    await generateNewCell(variant.id, scenario.id);

    return exp;
  }),

@@ -1,97 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";

export const modelOutputsRouter = createTRPCRouter({
  get: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
        channel: z.string().optional(),
        forceRefetch: z.boolean().optional(),
      }),
    )
    .mutation(async ({ input }) => {
      const existing = await prisma.modelOutput.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
      });

      if (existing && !input.forceRefetch) return existing;

      const variant = await prisma.promptVariant.findUnique({
        where: {
          id: input.variantId,
        },
      });

      const scenario = await prisma.testScenario.findUnique({
        where: {
          id: input.scenarioId,
        },
      });

      if (!variant || !scenario) return null;

      const prompt = await constructPrompt(variant, scenario);

      const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

      // TODO: we should probably only use this if temperature=0
      const existingResponse = await prisma.modelOutput.findFirst({
        where: { inputHash, errorMessage: null },
      });

      let modelResponse: Awaited<ReturnType<typeof getCompletion>>;

      if (existingResponse) {
        modelResponse = {
          output: existingResponse.output as Prisma.InputJsonValue,
          statusCode: existingResponse.statusCode,
          errorMessage: existingResponse.errorMessage,
          timeToComplete: existingResponse.timeToComplete,
          promptTokens: existingResponse.promptTokens ?? undefined,
          completionTokens: existingResponse.completionTokens ?? undefined,
        };
      } else {
        try {
          modelResponse = await getCompletion(prompt, input.channel);
        } catch (e) {
          console.error(e);
          throw e;
        }
      }

      const modelOutput = await prisma.modelOutput.upsert({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        create: {
          promptVariantId: input.variantId,
          testScenarioId: input.scenarioId,
          inputHash,
          ...modelResponse,
        },
        update: {
          ...modelResponse,
        },
      });

      await reevaluateVariant(input.variantId);

      return modelOutput;
    }),
});

@@ -1,6 +1,12 @@
import dedent from "dedent";
import { isObject } from "lodash";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost";

@@ -39,17 +45,24 @@ export const promptVariantsRouter = createTRPCRouter({
        visible: true,
      },
    });
    const outputCount = await prisma.modelOutput.count({
    const outputCount = await prisma.scenarioVariantCell.count({
      where: {
        promptVariantId: input.variantId,
        testScenario: { visible: true },
        modelOutput: {
          isNot: null,
        },
      },
    });

    const overallTokens = await prisma.modelOutput.aggregate({
      where: {
        promptVariantId: input.variantId,
        testScenario: { visible: true },
        scenarioVariantCell: {
          promptVariantId: input.variantId,
          testScenario: {
            visible: true,
          },
        },
      },
      _sum: {
        promptTokens: true,
@@ -57,14 +70,10 @@ export const promptVariantsRouter = createTRPCRouter({
      },
    });

    // TODO: fix this
    const model = "gpt-3.5-turbo-0613";
    // const model = getModelName(variant.config);

    const promptTokens = overallTokens._sum?.promptTokens ?? 0;
    const overallPromptCost = calculateTokenCost(model, promptTokens);
    const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
    const completionTokens = overallTokens._sum?.completionTokens ?? 0;
    const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
    const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);

    const overallCost = overallPromptCost + overallCompletionCost;

@@ -105,7 +114,19 @@ export const promptVariantsRouter = createTRPCRouter({
        experimentId: input.experimentId,
        label: `Prompt Variant ${largestSortIndex + 2}`,
        sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
        constructFn: lastVariant?.constructFn ?? "",
        constructFn:
          lastVariant?.constructFn ??
          dedent`
            prompt = {
              model: "gpt-3.5-turbo",
              messages: [
                {
                  role: "system",
                  content: "Return 'Hello, world!'",
                }
              ]
            }`,
        model: lastVariant?.model ?? "gpt-3.5-turbo",
      },
    });

@@ -114,6 +135,17 @@ export const promptVariantsRouter = createTRPCRouter({
      recordExperimentUpdated(input.experimentId),
    ]);

    const scenarios = await prisma.testScenario.findMany({
      where: {
        experimentId: input.experimentId,
        visible: true,
      },
    });

    for (const scenario of scenarios) {
      await generateNewCell(newVariant.id, scenario.id);
    }

    return newVariant;
  }),

@@ -185,6 +217,27 @@ export const promptVariantsRouter = createTRPCRouter({
      throw new Error(`Prompt Variant with id ${input.id} does not exist`);
    }

    let model = existing.model;
    try {
      const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);

      if (!isObject(contructedPrompt)) {
        return userError("Prompt is not an object");
      }
      if (!("model" in contructedPrompt)) {
        return userError("Prompt does not define a model");
      }
      if (
        typeof contructedPrompt.model !== "string" ||
        !(contructedPrompt.model in OpenAIChatModel)
      ) {
        return userError("Prompt defines an invalid model");
      }
      model = contructedPrompt.model;
    } catch (e) {
      return userError((e as Error).message);
    }

    // Create a duplicate with only the config changed
    const newVariant = await prisma.promptVariant.create({
      data: {
@@ -193,11 +246,12 @@ export const promptVariantsRouter = createTRPCRouter({
        sortIndex: existing.sortIndex,
        uiId: existing.uiId,
        constructFn: input.constructFn,
        model,
      },
    });

    // Hide anything with the same uiId besides the new one
    const hideOldVariantsAction = prisma.promptVariant.updateMany({
    const hideOldVariants = prisma.promptVariant.updateMany({
      where: {
        uiId: existing.uiId,
        id: {
@@ -209,12 +263,20 @@ export const promptVariantsRouter = createTRPCRouter({
      },
    });

    await prisma.$transaction([
      hideOldVariantsAction,
      recordExperimentUpdated(existing.experimentId),
    ]);
    await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);

    return newVariant;
    const scenarios = await prisma.testScenario.findMany({
      where: {
        experimentId: newVariant.experimentId,
        visible: true,
      },
    });

    for (const scenario of scenarios) {
      await generateNewCell(newVariant.id, scenario.id);
    }

    return { status: "ok" } as const;
  }),

  reorder: publicProcedure

src/server/api/routers/scenarioVariantCells.router.ts (68 lines, new file)

@@ -0,0 +1,68 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";

export const scenarioVariantCellsRouter = createTRPCRouter({
  get: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .query(async ({ input }) => {
      return await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: true,
        },
      });
    }),
  forceRefetch: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .mutation(async ({ input }) => {
      const cell = await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: true,
        },
      });

      if (!cell) {
        await generateNewCell(input.variantId, input.scenarioId);
        return true;
      }

      if (cell.modelOutput) {
        // TODO: Maybe keep these around to show previous generations?
        await prisma.modelOutput.delete({
          where: { id: cell.modelOutput.id },
        });
      }

      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: { retrievalStatus: "PENDING" },
      });

      await queueLLMRetrievalTask(cell.id);
      return true;
    }),
});

@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";

export const scenariosRouter = createTRPCRouter({
  list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -48,10 +49,21 @@ export const scenariosRouter = createTRPCRouter({
      },
    });

    await prisma.$transaction([
    const [scenario] = await prisma.$transaction([
      createNewScenarioAction,
      recordExperimentUpdated(input.experimentId),
    ]);

    const promptVariants = await prisma.promptVariant.findMany({
      where: {
        experimentId: input.experimentId,
        visible: true,
      },
    });

    for (const variant of promptVariants) {
      await generateNewCell(variant.id, scenario.id);
    }
  }),

  hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
@@ -175,6 +187,17 @@ export const scenariosRouter = createTRPCRouter({
      },
    });

    const promptVariants = await prisma.promptVariant.findMany({
      where: {
        experimentId: newScenario.experimentId,
        visible: true,
      },
    });

    for (const variant of promptVariants) {
      await generateNewCell(variant.id, newScenario.id);
    }

    return newScenario;
  }),
});

src/server/scripts/migrateScenarioVariantOutputData.ts (47 lines, new file)

@@ -0,0 +1,47 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";

async function migrateScenarioVariantOutputData() {
  // Get all ScenarioVariantCells
  const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
  console.log(`Found ${cells.length} records`);

  let updatedCount = 0;

  // Loop through all scenarioVariants
  for (const cell of cells) {
    // Create a new ModelOutput for each ScenarioVariant with an existing output
    if (cell.output && !cell.modelOutput) {
      updatedCount++;
      await prisma.modelOutput.create({
        data: {
          scenarioVariantCellId: cell.id,
          inputHash: cell.inputHash || "",
          output: cell.output as Prisma.InputJsonValue,
          timeToComplete: cell.timeToComplete ?? undefined,
          promptTokens: cell.promptTokens,
          completionTokens: cell.completionTokens,
          createdAt: cell.createdAt,
          updatedAt: cell.updatedAt,
        },
      });
    } else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
      updatedCount++;
      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: {
          retrievalStatus: "ERROR",
        },
      });
    }
  }

  console.log("Data migration completed");
  console.log(`Updated ${updatedCount} records`);
}

// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
  console.error("An error occurred while migrating data: ", error);
  process.exit(1);
});

src/server/tasks/defineTask.ts (31 lines, new file)

@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";

// Define the defineTask function
function defineTask<TPayload>(
  taskIdentifier: string,
  taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
  const enqueue = async (payload: TPayload) => {
    console.log("Enqueuing task", taskIdentifier, payload);
    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
  };

  const handler = (payload: TPayload, helpers: Helpers) => {
    helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
    return taskHandler(payload, helpers);
  };

  const task = {
    identifier: taskIdentifier,
    handler: handler as Task,
  };

  return {
    enqueue,
    task,
  };
}

export default defineTask;
144
src/server/tasks/queryLLM.task.ts
Normal file
144
src/server/tasks/queryLLM.task.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
import crypto from "crypto";
|
||||
import { prisma } from "~/server/db";
|
||||
import defineTask from "./defineTask";
|
||||
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
|
||||
import { type JSONSerializable } from "../types";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import { shouldStream } from "../utils/shouldStream";
|
||||
import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";

const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}

const getCompletionWithRetries = async (
  cellId: string,
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> => {
  for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
    const modelResponse = await getCompletion(
      payload as unknown as CompletionCreateParams,
      channel,
    );
    if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
      return modelResponse;
    }
    const delay = calculateDelay(i);
    await prisma.scenarioVariantCell.update({
      where: { id: cellId },
      data: {
        errorMessage: "Rate limit exceeded",
        statusCode: 429,
        retryTime: new Date(Date.now() + delay),
      },
    });
    // TODO: Maybe requeue the job so other jobs can run in the future?
    await sleep(delay);
  }
  throw new Error("Max retries limit reached");
};

export type queryLLMJob = {
  scenarioVariantCellId: string;
};

export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
  const { scenarioVariantCellId } = task;
  const cell = await prisma.scenarioVariantCell.findUnique({
    where: { id: scenarioVariantCellId },
    include: { modelOutput: true },
  });
  if (!cell) {
    return;
  }

  // If cell is not pending, then some other job is already processing it
  if (cell.retrievalStatus !== "PENDING") {
    return;
  }
  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      retrievalStatus: "IN_PROGRESS",
    },
  });

  const variant = await prisma.promptVariant.findUnique({
    where: { id: cell.promptVariantId },
  });
  if (!variant) {
    return;
  }

  const scenario = await prisma.testScenario.findUnique({
    where: { id: cell.testScenarioId },
  });
  if (!scenario) {
    return;
  }

  const prompt = await constructPrompt(variant, scenario.variableValues);

  const streamingEnabled = shouldStream(prompt);
  let streamingChannel;

  if (streamingEnabled) {
    streamingChannel = generateChannel();
    // Save streaming channel so that UI can connect to it
    await prisma.scenarioVariantCell.update({
      where: { id: scenarioVariantCellId },
      data: {
        streamingChannel,
      },
    });
  }

  const modelResponse = await getCompletionWithRetries(
    scenarioVariantCellId,
    prompt,
    streamingChannel,
  );

  let modelOutput = null;
  if (modelResponse.statusCode === 200) {
    const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

    modelOutput = await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId,
        inputHash,
        output: modelResponse.output,
        timeToComplete: modelResponse.timeToComplete,
        promptTokens: modelResponse.promptTokens,
        completionTokens: modelResponse.completionTokens,
      },
    });
  }

  await prisma.scenarioVariantCell.update({
    where: { id: scenarioVariantCellId },
    data: {
      statusCode: modelResponse.statusCode,
      errorMessage: modelResponse.errorMessage,
      streamingChannel: null,
      retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
      modelOutput: {
        connect: {
          id: modelOutput?.id,
        },
      },
    },
  });

  await reevaluateVariant(cell.promptVariantId);
});
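For reference, a minimal sketch of the retry schedule calculateDelay above produces, using the MIN_DELAY/MAX_DELAY constants from this file: the base delay doubles per attempt, is capped at MAX_DELAY, and then up to 100% random jitter is added, so a single wait tops out at roughly 2 × MAX_DELAY.

// Sketch only: reproduces the arithmetic of calculateDelay with the constants above.
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds
for (let attempt = 0; attempt < 6; attempt++) {
  const base = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, attempt));
  // jitter adds between 0 and `base` extra milliseconds
  console.log(`attempt ${attempt}: ${base}-${base * 2} ms`);
}
// attempt 0: 500-1000 ms, attempt 1: 1000-2000 ms, ..., attempt 5: 15000-30000 ms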
40
src/server/tasks/worker.ts
Normal file
@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";

import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";

const registeredTasks = [queryLLM];

const taskList = registeredTasks.reduce((acc, task) => {
  acc[task.task.identifier] = task.task.handler;
  return acc;
}, {} as TaskList);

async function main() {
  // Run a worker to execute jobs:
  const runner = await run({
    connectionString: env.DATABASE_URL,
    concurrency: 20,
    // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
    noHandleSignals: false,
    pollInterval: 1000,
    // you can set the taskList or taskDirectory but not both
    taskList,
    // or:
    // taskDirectory: `${__dirname}/tasks`,
  });

  // Immediately await (or otherwise handle) the resulting promise, to avoid
  // "unhandled rejection" errors causing a process crash in the event of
  // something going wrong.
  await runner.promise;

  // If the worker exits (whether through fatal error or otherwise), the above
  // promise will resolve/reject.
}

main().catch((err) => {
  console.error("Unhandled error occurred running worker: ", err);
  process.exit(1);
});
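The worker builds its TaskList from the `.task.identifier` / `.task.handler` pair each registered task exposes. The defineTask helper imported by queryLLM.task.ts is not part of this diff, so the following is only an assumed sketch of its shape, inferred from how worker.ts and queueLLMRetrievalTask.ts consume it.

// Assumed shape only: defineTask itself is not shown in this diff.
import { type JobHelpers, type Task } from "graphile-worker";

export function defineTask<TPayload>(
  identifier: string,
  handler: (payload: TPayload, helpers: JobHelpers) => Promise<void>,
) {
  // graphile-worker hands tasks an untyped payload, so narrow it before delegating
  const task: Task = (payload, helpers) => handler(payload as TPayload, helpers);
  return { task: { identifier, handler: task } };
}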
@@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => {
    constructFn: `prompt = { "fooz": "bar" }`,
  },
  {
    variableValues: {
      foo: "bar",
    },
    foo: "bar",
  },
);
@@ -6,12 +6,10 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });

export async function constructPrompt(
  variant: Pick<PromptVariant, "constructFn">,
  testScenario: Pick<TestScenario, "variableValues">,
  scenario: TestScenario["variableValues"],
): Promise<JSONSerializable> {
  const scenario = testScenario.variableValues as JSONSerializable;

  const code = `
    const scenario = ${JSON.stringify(scenario, null, 2)};
    const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
    let prompt

    ${variant.constructFn}
6
src/server/utils/error.ts
Normal file
@@ -0,0 +1,6 @@
export default function userError(message: string): { status: "error"; message: string } {
  return {
    status: "error",
    message,
  };
}
@@ -1,4 +1,4 @@
import { type Evaluation } from "@prisma/client";
import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput";
@@ -12,21 +12,22 @@ export const reevaluateVariant = async (variantId: string) => {
    where: { experimentId: variant.experimentId },
  });

  const modelOutputs = await prisma.modelOutput.findMany({
  const cells = await prisma.scenarioVariantCell.findMany({
    where: {
      promptVariantId: variantId,
      statusCode: { notIn: [429] },
      retrievalStatus: "COMPLETE",
      testScenario: { visible: true },
      modelOutput: { isNot: null },
    },
    include: { testScenario: true },
    include: { testScenario: true, modelOutput: true },
  });

  await Promise.all(
    evaluations.map(async (evaluation) => {
      const passCount = modelOutputs.filter((output) =>
        evaluateOutput(output, output.testScenario, evaluation),
      const passCount = cells.filter((cell) =>
        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
      ).length;
      const failCount = modelOutputs.length - passCount;
      const failCount = cells.length - passCount;

      await prisma.evaluationResult.upsert({
        where: {
@@ -55,22 +56,23 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
    where: { experimentId: evaluation.experimentId, visible: true },
  });

  const modelOutputs = await prisma.modelOutput.findMany({
  const cells = await prisma.scenarioVariantCell.findMany({
    where: {
      promptVariantId: { in: variants.map((v) => v.id) },
      testScenario: { visible: true },
      statusCode: { notIn: [429] },
      modelOutput: { isNot: null },
    },
    include: { testScenario: true },
    include: { testScenario: true, modelOutput: true },
  });

  await Promise.all(
    variants.map(async (variant) => {
      const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
      const passCount = outputs.filter((output) =>
        evaluateOutput(output, output.testScenario, evaluation),
      const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
      const passCount = variantCells.filter((cell) =>
        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
      ).length;
      const failCount = outputs.length - passCount;
      const failCount = variantCells.length - passCount;

      await prisma.evaluationResult.upsert({
        where: {
76
src/server/utils/generateNewCell.ts
Normal file
@@ -0,0 +1,76 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";

export const generateNewCell = async (variantId: string, scenarioId: string) => {
  const variant = await prisma.promptVariant.findUnique({
    where: {
      id: variantId,
    },
  });

  const scenario = await prisma.testScenario.findUnique({
    where: {
      id: scenarioId,
    },
  });

  if (!variant || !scenario) return null;

  const prompt = await constructPrompt(variant, scenario.variableValues);

  const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

  let cell = await prisma.scenarioVariantCell.findUnique({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: variantId,
        testScenarioId: scenarioId,
      },
    },
    include: {
      modelOutput: true,
    },
  });

  if (cell) return cell;

  cell = await prisma.scenarioVariantCell.create({
    data: {
      promptVariantId: variantId,
      testScenarioId: scenarioId,
    },
    include: {
      modelOutput: true,
    },
  });

  const matchingModelOutput = await prisma.modelOutput.findFirst({
    where: {
      inputHash,
    },
  });

  let newModelOutput;

  if (matchingModelOutput) {
    newModelOutput = await prisma.modelOutput.create({
      data: {
        scenarioVariantCellId: cell.id,
        inputHash,
        output: matchingModelOutput.output as Prisma.InputJsonValue,
        timeToComplete: matchingModelOutput.timeToComplete,
        promptTokens: matchingModelOutput.promptTokens,
        completionTokens: matchingModelOutput.completionTokens,
        createdAt: matchingModelOutput.createdAt,
        updatedAt: matchingModelOutput.updatedAt,
      },
    });
  } else {
    cell = await queueLLMRetrievalTask(cell.id);
  }

  return { ...cell, modelOutput: newModelOutput };
};
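generateNewCell reuses earlier completions by keying them on a SHA-256 hash of the serialized prompt; a minimal illustration of that hashing (the prompt object below is illustrative only):

import crypto from "crypto";

// Identical prompts serialize identically, so they produce the same inputHash
// and a matching modelOutput can be copied instead of re-querying the model.
const hashPrompt = (prompt: unknown) =>
  crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

hashPrompt({ model: "gpt-3.5-turbo", messages: [] }) ===
  hashPrompt({ model: "gpt-3.5-turbo", messages: [] }); // true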
@@ -4,15 +4,12 @@ import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type JSONSerializable, OpenAIChatModel } from "../types";
import { type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { getModelName } from "./getModelName";
import { rateLimitErrorMessage } from "~/sharedStrings";

env;

type CompletionResponse = {
export type CompletionResponse = {
  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
  statusCode: number;
  errorMessage: string | null;
@@ -22,35 +19,7 @@ type CompletionResponse = {
};

export async function getCompletion(
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> {
  const modelName = getModelName(payload);
  if (!modelName)
    return {
      output: Prisma.JsonNull,
      statusCode: 400,
      errorMessage: "Invalid payload provided",
      timeToComplete: 0,
    };
  if (modelName in OpenAIChatModel) {
    return getOpenAIChatCompletion(
      payload as unknown as CompletionCreateParams,
      env.OPENAI_API_KEY,
      channel,
    );
  }
  return {
    output: Prisma.JsonNull,
    statusCode: 400,
    errorMessage: "Invalid model provided",
    timeToComplete: 0,
  };
}

export async function getOpenAIChatCompletion(
  payload: CompletionCreateParams,
  apiKey: string,
  channel?: string,
): Promise<CompletionResponse> {
  // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify(payload),
  });
@@ -1,9 +0,0 @@
import { isObject } from "lodash";
import { type JSONSerializable, type SupportedModel } from "../types";
import { type Prisma } from "@prisma/client";

export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
  if (!isObject(config)) return null;
  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
  return null;
}
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";

export const queueLLMRetrievalTask = async (cellId: string) => {
  const updatedCell = await prisma.scenarioVariantCell.update({
    where: {
      id: cellId,
    },
    data: {
      retrievalStatus: "PENDING",
      errorMessage: null,
    },
    include: {
      modelOutput: true,
    },
  });

  // @ts-expect-error we aren't passing the helpers but that's ok
  void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });

  return updatedCell;
};
7
src/server/utils/shouldStream.ts
Normal file
@@ -0,0 +1,7 @@
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";

export const shouldStream = (config: JSONSerializable): boolean => {
  const shouldStream = isObject(config) && "stream" in config && config.stream === true;
  return shouldStream;
};
1
src/server/utils/sleep.ts
Normal file
@@ -0,0 +1 @@
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
@@ -2,6 +2,9 @@ import { type RouterOutputs } from "~/utils/api";
import { type SliceCreator } from "./store";
import loader from "@monaco-editor/loader";
import openAITypes from "~/codegen/openai.types.ts.txt";
import formatPromptConstructor from "~/utils/formatPromptConstructor";

export const editorBackground = "#fafafa";

export type SharedVariantEditorSlice = {
  monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
@@ -14,6 +17,9 @@ export type SharedVariantEditorSlice = {
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
  monaco: loader.__getMonacoInstance(),
  loadMonaco: async () => {
    // We only want to run this client-side
    if (typeof window === "undefined") return;

    const monaco = await loader.init();

    monaco.editor.defineTheme("customTheme", {
@@ -21,12 +27,13 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
      inherit: true,
      rules: [],
      colors: {
        "editor.background": "#fafafa",
        "editor.background": editorBackground,
      },
    });

    monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
      allowNonTsExtensions: true,
      strictNullChecks: true,
      lib: ["esnext"],
    });
@@ -40,6 +47,17 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
      monaco.Uri.parse("file:///openai.types.ts"),
    );

    monaco.languages.registerDocumentFormattingEditProvider("typescript", {
      provideDocumentFormattingEdits: async (model) => {
        return [
          {
            range: model.getFullModelRange(),
            text: await formatPromptConstructor(model.getValue()),
          },
        ];
      },
    });

    set((state) => {
      state.sharedVariantEditor.monaco = monaco;
    });
@@ -67,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
    )} as const;

    type Scenario = typeof scenarios[number];
    declare var scenario: Scenario | null;
    declare var scenario: Scenario | { [key: string]: string };
    `;

    const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
};

export const calculateTokenCost = (
  model: SupportedModel | null,
  model: SupportedModel | string | null,
  numTokens: number,
  isCompletion = false,
) => {
10
src/utils/formatPromptConstructor.test.ts
Normal file
@@ -0,0 +1,10 @@
import { expect, test } from "vitest";
import { stripTypes } from "./formatPromptConstructor";

test("stripTypes", () => {
  expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);
});

test("stripTypes with invalid syntax", () => {
  expect(stripTypes(`asdf foo: string = "bar"`)).toBe(`asdf foo: string = "bar"`);
});
31
src/utils/formatPromptConstructor.ts
Normal file
@@ -0,0 +1,31 @@
import prettier from "prettier/standalone";
import parserTypescript from "prettier/plugins/typescript";

// @ts-expect-error for some reason missing from types
import parserEstree from "prettier/plugins/estree";

import * as babel from "@babel/standalone";

export function stripTypes(tsCode: string): string {
  const options = {
    presets: ["typescript"],
    filename: "file.ts",
  };

  try {
    const result = babel.transform(tsCode, options);
    return result.code ?? tsCode;
  } catch (error) {
    console.error("Error stripping types", error);
    return tsCode;
  }
}

export default async function formatPromptConstructor(code: string): Promise<string> {
  return await prettier.format(stripTypes(code), {
    parser: "typescript",
    plugins: [parserTypescript, parserEstree],
    // We're showing these in pretty narrow panes so let's keep the print width low
    printWidth: 60,
  });
}
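A small usage sketch of the formatter above, called the way the Monaco format provider does (the prompt constructor string is illustrative only):

import formatPromptConstructor from "~/utils/formatPromptConstructor";

// Types are stripped with Babel first, then Prettier reflows the result to 60 columns.
void formatPromptConstructor(
  'prompt = { model: "gpt-3.5-turbo", messages: [{ role: "user", content: scenario.prompt }] };',
).then(console.log);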
@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";

const url = env.NEXT_PUBLIC_SOCKET_URL;

export default function useSocket(channel?: string) {
export default function useSocket(channel?: string | null) {
  const socketRef = useRef<Socket>();
  const [message, setMessage] = useState<ChatCompletion | null>(null);