Compare commits


19 Commits

Author SHA1 Message Date
Kyle Corbitt
74c201d3a8 Make save button disappear on save
Fixes a bug where the "Save" button often wouldn't disappear as expected the first time you clicked it.
2023-07-14 17:35:57 -07:00
David Corbitt
ab9c721d09 Revert change to scenarios header 2023-07-14 17:51:12 -06:00
David Corbitt
0a2578a1d8 Update scenarios header negative margin 2023-07-14 17:41:51 -06:00
David Corbitt
1bebaff386 Merge branch 'main' of github.com:corbt/prompt-lab 2023-07-14 16:55:12 -06:00
David Corbitt
3bf5eaf4a2 Properly extract scenario id in new experiment creation 2023-07-14 16:55:09 -06:00
Kyle Corbitt
ded97f8bb9 fix lockfile 2023-07-14 15:55:01 -07:00
Kyle Corbitt
26ee8698be Make it so you can't delete the last prompt or scenario
There's no reason for an experiment to have 0 prompts or 0 scenarios, and it makes the UI look bad.
2023-07-14 15:49:42 -07:00
arcticfly
b98eb9b729 Trigger llm output retrieval on server (#39)
* Rename tables, add graphile workers, update types

* Add dev:worker command

* Update pnpm-lock.yaml

* Remove sentry config import from worker.ts

* Stop generating new cells in cell router get query

* Generate new cells for new scenarios, variants, and experiments

* Remove most error throwing from queryLLM.task.ts

* Remove promptVariantId and testScenarioId from ModelOutput

* Remove duplicate index from ModelOutput

* Move inputHash from cell to output

* Add TODO

* Add todo

* Show cost and time for each cell

* Always show output stats if there is output

* Trigger LLM outputs when scenario variables are updated

* Add newlines to ends of files

* Add another newline

* Cascade ModelOutput deletion

* Fix linting and prettier

* Return instead of throwing for non-pending cell

* Remove pnpm dev:worker from pnpm:dev

* Update pnpm-lock.yaml
2023-07-14 16:38:46 -06:00
Kyle Corbitt
032c07ec65 Merge pull request #45 from OpenPipe/node-version
warn folks if they use a lower node version
2023-07-14 15:03:49 -07:00
Kyle Corbitt
80c0d13bb9 warn folks if they use a lower node version 2023-07-14 14:59:33 -07:00
Kyle Corbitt
f7c94be3f6 Merge pull request #44 from OpenPipe/strip-types
Strip types from prompt variants
2023-07-14 14:07:07 -07:00
Kyle Corbitt
c3e85607e0 Strip types from prompt variants
We want Monaco to treat the prompt constructor as TypeScript so we get type checks, but we actually want to save the prompt constructor as JavaScript so we can run it directly without transpiling.
2023-07-14 14:03:28 -07:00
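The package.json diff below adds @babel/standalone and @babel/preset-typescript for this, though the stripping call site itself isn't part of this compare view. A minimal sketch of what that step likely looks like — the function and filename here are illustrative, only the @babel/standalone API usage is real:

// Hypothetical sketch of the TS -> JS stripping step.
import { transform } from "@babel/standalone";

export function stripTypes(constructFnTs: string): string {
  const result = transform(constructFnTs, {
    filename: "constructFn.ts", // tells Babel to parse the source as TypeScript
    presets: ["typescript"], // strips type annotations, emits plain JavaScript
  });
  return result?.code ?? "";
}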
Kyle Corbitt
cd5927b8f5 Merge pull request #43 from OpenPipe/function-ux
Pseudo function signatures
2023-07-14 14:01:10 -07:00
Kyle Corbitt
731406d1f4 Pseudo function signatures
Show pseudo function signatures in the variant editor box as a UX hint that you're typing in JavaScript and have access to the scenario.
2023-07-14 13:56:45 -07:00
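Concretely, the VariantEditor diff below renders a static signature line above the Monaco editor and a closing line below it, so the user's saved constructFn reads as the body of a function like this (assembled here purely for illustration):

// Illustration only: the signature and closing brace are static <code> elements
// around the editor; the middle is the user's saved constructFn body.
function constructPrompt(scenario: Scenario): Prompt {
  prompt = {
    model: "gpt-3.5-turbo-0613",
    stream: true,
    messages: [{ role: "system", content: `Return '${scenario.text}'` }],
  };
  return prompt;
}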
Kyle Corbitt
3c59e4b774 Merge pull request #42 from OpenPipe/autoformat
implement format on save
2023-07-14 12:56:41 -07:00
Kyle Corbitt
972b1f2333 Merge pull request #41 from OpenPipe/github-actions
CI checks
2023-07-14 11:40:42 -07:00
Kyle Corbitt
7321f3deda CI checks 2023-07-14 11:36:47 -07:00
Kyle Corbitt
2bd41fdfbf Merge pull request #40 from OpenPipe:completion-costs
store model and use to calculate completion costs
2023-07-14 11:07:15 -07:00
Kyle Corbitt
a5378b106b store model and use to calculate completion costs 2023-07-14 11:06:07 -07:00
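One file this compare view doesn't include is ~/utils/calculateTokenCost, though the promptVariants router below calls it as calculateTokenCost(model, tokens) for prompt tokens and calculateTokenCost(model, tokens, true) for completion tokens. A sketch consistent with that signature — the price table is illustrative (circa mid-2023 OpenAI list prices), not taken from the repo:

// Hypothetical sketch of ~/utils/calculateTokenCost.
// Prices are illustrative, in USD per 1K tokens.
const pricing: Record<string, { prompt: number; completion: number }> = {
  "gpt-3.5-turbo": { prompt: 0.0015, completion: 0.002 },
  "gpt-4": { prompt: 0.03, completion: 0.06 },
};

export function calculateTokenCost(
  model: string | null,
  tokens: number,
  isCompletion = false,
): number {
  const price = model ? pricing[model] : undefined;
  if (!price) return 0; // unknown model: report zero cost rather than guessing
  return ((isCompletion ? price.completion : price.prompt) * tokens) / 1000;
}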
44 changed files with 1713 additions and 591 deletions

.github/workflows/ci.yaml (vendored, new file)

@@ -0,0 +1,51 @@
name: CI checks

on:
  pull_request:
    branches: [main]

jobs:
  run-checks:
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v2

      - name: Set up Node.js
        uses: actions/setup-node@v2
        with:
          node-version: "20"

      - uses: pnpm/action-setup@v2
        name: Install pnpm
        id: pnpm-install
        with:
          version: 8.6.1
          run_install: false

      - name: Get pnpm store directory
        id: pnpm-cache
        shell: bash
        run: |
          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT

      - uses: actions/cache@v3
        name: Setup pnpm cache
        with:
          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-store-

      - name: Install Dependencies
        run: pnpm install

      - name: Check types
        run: pnpm tsc

      - name: Lint
        run: SKIP_ENV_VALIDATION=1 pnpm lint

      - name: Check prettier
        run: pnpm prettier . --check

.tool-versions (new file)

@@ -0,0 +1 @@
nodejs 20.2.0

package.json

@@ -3,10 +3,15 @@
   "type": "module",
   "version": "0.1.0",
   "license": "Apache-2.0",
+  "engines": {
+    "node": ">=20.0.0",
+    "pnpm": ">=8.6.1"
+  },
   "scripts": {
     "build": "next build",
     "dev:next": "next dev",
     "dev:wss": "pnpm tsx --watch src/wss-server.ts",
+    "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
     "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
     "postinstall": "prisma generate",
     "lint": "next lint",
@@ -14,6 +19,8 @@
     "codegen": "tsx src/codegen/export-openai-types.ts"
   },
   "dependencies": {
+    "@babel/preset-typescript": "^7.22.5",
+    "@babel/standalone": "^7.22.9",
     "@chakra-ui/next-js": "^2.1.4",
     "@chakra-ui/react": "^2.7.1",
     "@emotion/react": "^11.11.1",
@@ -38,6 +45,7 @@
     "express": "^4.18.2",
     "framer-motion": "^10.12.17",
     "gpt-tokens": "^1.0.10",
+    "graphile-worker": "^0.13.0",
     "immer": "^10.0.2",
     "isolated-vm": "^4.5.0",
     "json-stringify-pretty-compact": "^4.0.0",
@@ -63,6 +71,8 @@
   },
   "devDependencies": {
     "@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
+    "@types/babel__core": "^7.20.1",
+    "@types/babel__standalone": "^7.1.4",
     "@types/chroma-js": "^2.4.0",
     "@types/cors": "^2.8.13",
     "@types/eslint": "^8.37.0",

pnpm-lock.yaml (generated)

File diff suppressed because it is too large.

prisma/migrations/…/migration.sql (new file)

@@ -0,0 +1,49 @@
-- Drop the foreign key constraints on the original ModelOutput
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_promptVariantId_fkey";
ALTER TABLE "ModelOutput" DROP CONSTRAINT "ModelOutput_testScenarioId_fkey";
-- Rename the old table
ALTER TABLE "ModelOutput" RENAME TO "ScenarioVariantCell";
ALTER TABLE "ScenarioVariantCell" RENAME CONSTRAINT "ModelOutput_pkey" TO "ScenarioVariantCell_pkey";
ALTER INDEX "ModelOutput_inputHash_idx" RENAME TO "ScenarioVariantCell_inputHash_idx";
ALTER INDEX "ModelOutput_promptVariantId_testScenarioId_key" RENAME TO "ScenarioVariantCell_promptVariantId_testScenarioId_key";
-- Add the new fields to the renamed table
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retryTime" TIMESTAMP(3);
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "streamingChannel" TEXT;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "inputHash" DROP NOT NULL;
ALTER TABLE "ScenarioVariantCell" ALTER COLUMN "output" DROP NOT NULL,
ALTER COLUMN "statusCode" DROP NOT NULL,
ALTER COLUMN "timeToComplete" DROP NOT NULL;
-- Create the new table
CREATE TABLE "ModelOutput" (
  "id" UUID NOT NULL,
  "inputHash" TEXT NOT NULL,
  "output" JSONB NOT NULL,
  "timeToComplete" INTEGER NOT NULL DEFAULT 0,
  "promptTokens" INTEGER,
  "completionTokens" INTEGER,
  "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
  "updatedAt" TIMESTAMP(3) NOT NULL,
  "scenarioVariantCellId" UUID
);
-- Move inputHash index
DROP INDEX "ScenarioVariantCell_inputHash_idx";
CREATE INDEX "ModelOutput_inputHash_idx" ON "ModelOutput"("inputHash");
CREATE UNIQUE INDEX "ModelOutput_scenarioVariantCellId_key" ON "ModelOutput"("scenarioVariantCellId");
ALTER TABLE "ModelOutput" ADD CONSTRAINT "ModelOutput_scenarioVariantCellId_fkey" FOREIGN KEY ("scenarioVariantCellId") REFERENCES "ScenarioVariantCell"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ModelOutput" ALTER COLUMN "scenarioVariantCellId" SET NOT NULL,
ADD CONSTRAINT "ModelOutput_pkey" PRIMARY KEY ("id");
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ScenarioVariantCell" ADD CONSTRAINT "ScenarioVariantCell_testScenarioId_fkey" FOREIGN KEY ("testScenarioId") REFERENCES "TestScenario"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- CreateEnum
CREATE TYPE "CellRetrievalStatus" AS ENUM ('PENDING', 'IN_PROGRESS', 'COMPLETE', 'ERROR');
-- AlterTable
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "retrievalStatus" "CellRetrievalStatus" NOT NULL DEFAULT 'COMPLETE';

prisma/migrations/…/migration.sql (new file)

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ADD COLUMN "model" TEXT NOT NULL DEFAULT 'gpt-3.5-turbo';

prisma/schema.prisma

@@ -27,9 +27,10 @@ model Experiment {
 model PromptVariant {
   id String @id @default(uuid()) @db.Uuid
   label String
   constructFn String
+  model String @default("gpt-3.5-turbo")
   uiId String @default(uuid()) @db.Uuid
   visible Boolean @default(true)
@@ -40,7 +41,7 @@ model PromptVariant {
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput ModelOutput[]
+  scenarioVariantCells ScenarioVariantCell[]
   EvaluationResult EvaluationResult[]

   @@index([uiId])
@@ -60,7 +61,7 @@ model TestScenario {
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
-  ModelOutput ModelOutput[]
+  scenarioVariantCells ScenarioVariantCell[]
 }

 model TemplateVariable {
@@ -75,17 +76,28 @@ model TemplateVariable {
   updatedAt DateTime @updatedAt
 }

-model ModelOutput {
+enum CellRetrievalStatus {
+  PENDING
+  IN_PROGRESS
+  COMPLETE
+  ERROR
+}
+
+model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid
-  inputHash String
-  output Json
-  statusCode Int
+  inputHash String? // TODO: Remove once migration is complete
+  output Json? // TODO: Remove once migration is complete
+  statusCode Int?
   errorMessage String?
-  timeToComplete Int @default(0)
+  timeToComplete Int? @default(0) // TODO: Remove once migration is complete
+  retryTime DateTime?
+  streamingChannel String?
+  retrievalStatus CellRetrievalStatus @default(COMPLETE)
-  promptTokens Int? // Added promptTokens field
-  completionTokens Int? // Added completionTokens field
+  promptTokens Int? // TODO: Remove once migration is complete
+  completionTokens Int? // TODO: Remove once migration is complete
+  modelOutput ModelOutput?
   promptVariantId String @db.Uuid
   promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -97,6 +109,24 @@ model ModelOutput {
   updatedAt DateTime @updatedAt

   @@unique([promptVariantId, testScenarioId])
+}
+
+model ModelOutput {
+  id String @id @default(uuid()) @db.Uuid
+  inputHash String
+  output Json
+  timeToComplete Int @default(0)
+  promptTokens Int?
+  completionTokens Int?
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+  scenarioVariantCellId String @db.Uuid
+  scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
+
+  @@unique([scenarioVariantCellId])
   @@index([inputHash])
 }

prisma/seed.ts

@@ -16,7 +16,7 @@ const experiment = await prisma.experiment.create({
   },
 });

-await prisma.modelOutput.deleteMany({
+await prisma.scenarioVariantCell.deleteMany({
   where: {
     promptVariant: {
       experimentId,

OutputCell/CellOptions.tsx (new file)

@@ -0,0 +1,33 @@
import { Button, HStack, Icon } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";

export const CellOptions = ({
  refetchingOutput,
  refetchOutput,
}: {
  refetchingOutput: boolean;
  refetchOutput: () => void;
}) => {
  return (
    <HStack justifyContent="flex-end" w="full">
      {!refetchingOutput && (
        <Button
          size="xs"
          w={4}
          h={4}
          py={4}
          px={4}
          minW={0}
          borderRadius={8}
          color="gray.500"
          variant="ghost"
          cursor="pointer"
          onClick={refetchOutput}
          aria-label="refetch output"
        >
          <Icon as={BsArrowClockwise} boxSize={4} />
        </Button>
      )}
    </HStack>
  );
};

OutputCell/ErrorHandler.tsx

@@ -1,29 +1,21 @@
-import { type ModelOutput } from "@prisma/client";
-import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
+import { type ScenarioVariantCell } from "@prisma/client";
+import { VStack, Text } from "@chakra-ui/react";
 import { useEffect, useState } from "react";
-import { BsArrowClockwise } from "react-icons/bs";
-import { rateLimitErrorMessage } from "~/sharedStrings";
 import pluralize from "pluralize";

-const MAX_AUTO_RETRIES = 3;
-
 export const ErrorHandler = ({
-  output,
+  cell,
   refetchOutput,
-  numPreviousTries,
 }: {
-  output: ModelOutput;
+  cell: ScenarioVariantCell;
   refetchOutput: () => void;
-  numPreviousTries: number;
 }) => {
   const [msToWait, setMsToWait] = useState(0);

-  const shouldAutoRetry =
-    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
-
   useEffect(() => {
-    if (!shouldAutoRetry) return;
-    const initialWaitTime = calculateDelay(numPreviousTries);
+    if (!cell.retryTime) return;
+    const initialWaitTime = cell.retryTime.getTime() - Date.now();
     const msModuloOneSecond = initialWaitTime % 1000;
     let remainingTime = initialWaitTime - msModuloOneSecond;
     setMsToWait(remainingTime);
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
       setMsToWait(remainingTime);

       if (remainingTime <= 0) {
-        refetchOutput();
         clearInterval(interval);
       }
     }, 1000);
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
       clearInterval(interval);
       clearTimeout(timeout);
     };
-  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
+  }, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);

   return (
     <VStack w="full">
-      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
-        <Text color="red.600" fontWeight="bold">
-          Error
-        </Text>
-        <Button
-          size="xs"
-          w={4}
-          h={4}
-          px={4}
-          py={4}
-          minW={0}
-          borderRadius={8}
-          variant="ghost"
-          cursor="pointer"
-          onClick={refetchOutput}
-          aria-label="refetch output"
-        >
-          <Icon as={BsArrowClockwise} boxSize={6} />
-        </Button>
-      </HStack>
       <Text color="red.600" wordBreak="break-word">
-        {output.errorMessage}
+        {cell.errorMessage}
       </Text>
       {msToWait > 0 && (
         <Text color="red.600" fontSize="sm">
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
     </VStack>
   );
 };
-
-const MIN_DELAY = 500; // milliseconds
-const MAX_DELAY = 5000; // milliseconds
-
-function calculateDelay(numPreviousTries: number): number {
-  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
-  const jitter = Math.random() * baseDelay;
-  return baseDelay + jitter;
-}

OutputCell/OutputCell.tsx

@@ -1,17 +1,16 @@
-import { type RouterOutputs, api } from "~/utils/api";
+import { api } from "~/utils/api";
 import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
+import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
+import { type ReactElement, useState, useEffect } from "react";
 import { type ChatCompletion } from "openai/resources/chat";
-import { generateChannel } from "~/utils/generateChannel";
-import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
 import { OutputStats } from "./OutputStats";
 import { ErrorHandler } from "./ErrorHandler";
+import { CellOptions } from "./CellOptions";

 export default function OutputCell({
   scenario,
@@ -37,92 +36,68 @@ export default function OutputCell({
   // if (variant.config === null || Object.keys(variant.config).length === 0)
   //   disabledReason = "Save your prompt variant to see output";

-  // const model = getModelName(variant.config as JSONSerializable);
-  // TODO: Temporarily hardcoding this while we get other stuff working
-  const model = "gpt-3.5-turbo";
-
-  const outputMutation = api.outputs.get.useMutation();
-
-  const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
-  const [channel, setChannel] = useState<string | undefined>(undefined);
-  const [numPreviousTries, setNumPreviousTries] = useState(0);
-  const fetchMutex = useRef(false);
-  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
-    async (forceRefetch?: boolean) => {
-      if (fetchMutex.current) return;
-      setNumPreviousTries((prev) => prev + 1);
-      fetchMutex.current = true;
-      setOutput(null);
-      const shouldStream =
-        isObject(variant) &&
-        "config" in variant &&
-        isObject(variant.config) &&
-        "stream" in variant.config &&
-        variant.config.stream === true;
-      const channel = shouldStream ? generateChannel() : undefined;
-      setChannel(channel);
-      const output = await outputMutation.mutateAsync({
-        scenarioId: scenario.id,
-        variantId: variant.id,
-        channel,
-        forceRefetch,
-      });
-      setOutput(output);
-      await utils.promptVariants.stats.invalidate();
-      fetchMutex.current = false;
-    },
-    [outputMutation, scenario.id, variant.id],
-  );
-  const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
-
-  useEffect(fetchOutput, [scenario.id, variant.id]);
+  const [refetchInterval, setRefetchInterval] = useState(0);
+  const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
+    { scenarioId: scenario.id, variantId: variant.id },
+    { refetchInterval },
+  );
+
+  const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
+    api.scenarioVariantCells.forceRefetch.useMutation();
+  const [hardRefetch] = useHandledAsyncCallback(async () => {
+    await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
+    await utils.scenarioVariantCells.get.invalidate({
+      scenarioId: scenario.id,
+      variantId: variant.id,
+    });
+  }, [hardRefetchMutate, scenario.id, variant.id]);
+
+  const fetchingOutput = queryLoading || refetchingOutput;
+
+  const awaitingOutput =
+    !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
+  useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
+
+  const modelOutput = cell?.modelOutput;

   // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
+  const streamedMessage = useSocket(cell?.streamingChannel);
   const streamedContent = streamedMessage?.choices?.[0]?.message?.content;

   if (!vars) return null;

   if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;

-  if (fetchingOutput && !streamedMessage)
+  if (awaitingOutput && !streamedMessage)
     return (
       <Center h="100%" w="100%">
         <Spinner />
       </Center>
     );

-  if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
+  if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;

-  if (output && output.errorMessage) {
-    return (
-      <ErrorHandler
-        output={output}
-        refetchOutput={hardRefetch}
-        numPreviousTries={numPreviousTries}
-      />
-    );
+  if (cell && cell.errorMessage) {
+    return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
   }

-  const response = output?.output as unknown as ChatCompletion;
+  const response = modelOutput?.output as unknown as ChatCompletion;
   const message = response?.choices?.[0]?.message;

-  if (output && message?.function_call) {
+  if (modelOutput && message?.function_call) {
    const rawArgs = message.function_call.arguments ?? "null";
    let parsedArgs: string;
    try {
      parsedArgs = JSON.parse(rawArgs);
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    } catch (e: any) {
      parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
    }

    return (
      <Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
+        <VStack w="full" spacing={0}>
+          <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
          <SyntaxHighlighter
            customStyle={{ overflowX: "unset" }}
            language="json"
@@ -140,17 +115,24 @@ export default function OutputCell({
            { maxLength: 40 },
          )}
          </SyntaxHighlighter>
-        <OutputStats model={model} modelOutput={output} scenario={scenario} />
+        </VStack>
+        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
      </Box>
    );
  }

-  const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
+  const contentToDisplay =
+    message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);

   return (
     <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
-      {contentToDisplay}
-      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
+      <VStack w="full" alignItems="flex-start" spacing={0}>
+        <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
+        <Text>{contentToDisplay}</Text>
+      </VStack>
+      {modelOutput && (
+        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
+      )}
     </Flex>
   );
 }

OutputCell/OutputStats.tsx

@@ -9,15 +9,15 @@ import { HStack, Icon, Text } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";

-const SHOW_COST = false;
-const SHOW_TIME = false;
+const SHOW_COST = true;
+const SHOW_TIME = true;

 export const OutputStats = ({
   model,
   modelOutput,
   scenario,
 }: {
-  model: SupportedModel | null;
+  model: SupportedModel | string | null;
   modelOutput: ModelOutput;
   scenario: Scenario;
 }) => {
@@ -35,8 +35,6 @@ export const OutputStats = ({
   const cost = promptCost + completionCost;

-  if (!evals.length) return null;
-
   return (
     <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>

ScenarioEditor.tsx

@@ -13,10 +13,11 @@ import AutoResizeTextArea from "../AutoResizeTextArea";
 export default function ScenarioEditor({
   scenario,
-  hovered,
+  ...props
 }: {
   scenario: Scenario;
   hovered: boolean;
+  canHide: boolean;
 }) {
   const savedValues = scenario.variableValues as Record<string, string>;
   const utils = api.useContext();
@@ -92,7 +93,9 @@
       onDrop={onReorder}
       backgroundColor={isDragTarget ? "gray.100" : "transparent"}
     >
-      <Stack alignSelf="flex-start" opacity={hovered ? 1 : 0} spacing={0}>
+      <Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
+        {props.canHide && (
+          <>
         <Tooltip label="Hide scenario" hasArrow>
           {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
           <Button
@@ -116,6 +119,8 @@
             color="gray.400"
             _hover={{ color: "gray.800", cursor: "pointer" }}
           />
+          </>
+        )}
       </Stack>
       {variableLabels.length === 0 ? (
         <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>

ScenarioRow.tsx

@@ -5,7 +5,11 @@ import OutputCell from "./OutputCell/OutputCell";
 import ScenarioEditor from "./ScenarioEditor";
 import type { PromptVariant, Scenario } from "./types";

-const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
+const ScenarioRow = (props: {
+  scenario: Scenario;
+  variants: PromptVariant[];
+  canHide: boolean;
+}) => {
   const [isHovered, setIsHovered] = useState(false);

   const highlightStyle = { backgroundColor: "gray.50" };
@@ -18,7 +22,7 @@ const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) =
         sx={isHovered ? highlightStyle : undefined}
         borderLeftWidth={1}
       >
-        <ScenarioEditor scenario={props.scenario} hovered={isHovered} />
+        <ScenarioEditor scenario={props.scenario} hovered={isHovered} canHide={props.canHide} />
       </GridItem>
       {props.variants.map((variant) => (
         <GridItem

VariantEditor.tsx

@@ -1,12 +1,12 @@
-import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
+import { Box, Button, HStack, Tooltip, VStack, useToast } from "@chakra-ui/react";
 import { useRef, useEffect, useState, useCallback } from "react";
 import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
 import { type PromptVariant } from "./types";
 import { api } from "~/utils/api";
 import { useAppStore } from "~/state/store";
-// import openAITypes from "~/codegen/openai.types.ts.txt";
+import { editorBackground } from "~/state/sharedVariantEditor.slice";

-export default function VariantConfigEditor(props: { variant: PromptVariant }) {
+export default function VariantEditor(props: { variant: PromptVariant }) {
   const monaco = useAppStore.use.sharedVariantEditor.monaco();
   const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
   const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -18,10 +18,12 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
   const checkForChanges = useCallback(() => {
     if (!editorRef.current) return;
-    const currentConfig = editorRef.current.getValue();
-    setIsChanged(currentConfig !== lastSavedFn);
+    const currentFn = editorRef.current.getValue();
+    setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
   }, [lastSavedFn]);

+  useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
+
   const replaceVariant = api.promptVariants.replaceVariant.useMutation();
   const utils = api.useContext();
   const toast = useToast();
@@ -64,14 +66,21 @@
       return;
     }

-    await replaceVariant.mutateAsync({
+    const resp = await replaceVariant.mutateAsync({
       id: props.variant.id,
       constructFn: currentFn,
     });
+    if (resp.status === "error") {
+      return toast({
+        title: "Error saving variant",
+        description: resp.message,
+        status: "error",
+      });
+    }
+
+    setIsChanged(false);

     await utils.promptVariants.list.invalidate();
-    checkForChanges();
   }, [checkForChanges]);

   useEffect(() => {
@@ -122,21 +131,21 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
     /* eslint-disable-next-line react-hooks/exhaustive-deps */
   }, [monaco, editorId]);

-  // useEffect(() => {
-  //   const savedConfigChanged = lastSavedFn !== savedConfig;
-  //   lastSavedFn = savedConfig;
-  //   if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
-  //     editorRef.current?.setValue(savedConfig);
-  //   }
-  //   checkForChanges();
-  // }, [savedConfig, checkForChanges]);
-
   return (
     <Box w="100%" pos="relative">
+      <VStack
+        spacing={0}
+        align="stretch"
+        fontSize="xs"
+        fontWeight="bold"
+        color="gray.600"
+        py={2}
+        bgColor={editorBackground}
+      >
+        <code>{`function constructPrompt(scenario: Scenario): Prompt {`}</code>
         <div id={editorId} style={{ height: "300px", width: "100%" }}></div>
+        <code>{`return prompt; }`}</code>
+      </VStack>
       {isChanged && (
         <HStack pos="absolute" bottom={2} right={2}>
           <Button

VariantHeader.tsx

@@ -8,7 +8,7 @@ import { RiDraggable } from "react-icons/ri";
 import { cellPadding, headerMinHeight } from "../constants";
 import AutoResizeTextArea from "../AutoResizeTextArea";

-export default function VariantHeader(props: { variant: PromptVariant }) {
+export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
   const utils = api.useContext();
   const [isDragTarget, setIsDragTarget] = useState(false);
   const [isInputHovered, setIsInputHovered] = useState(false);
@@ -95,11 +95,13 @@
         onMouseEnter={() => setIsInputHovered(true)}
         onMouseLeave={() => setIsInputHovered(false)}
       />
-      <Tooltip label="Hide Variant" hasArrow>
+      {props.canHide && (
+        <Tooltip label="Remove Variant" hasArrow>
         <Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
           <Icon as={BsX} boxSize={6} />
         </Button>
       </Tooltip>
+      )}
     </HStack>
   );
 }

OutputsTable.tsx

@@ -3,7 +3,7 @@ import { api } from "~/utils/api";
 import NewScenarioButton from "./NewScenarioButton";
 import NewVariantButton from "./NewVariantButton";
 import ScenarioRow from "./ScenarioRow";
-import VariantConfigEditor from "./VariantEditor";
+import VariantEditor from "./VariantEditor";
 import VariantHeader from "./VariantHeader";
 import { cellPadding } from "../constants";
 import { BsPencil } from "react-icons/bs";
@@ -78,7 +78,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
       {variants.data.map((variant) => (
         <GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
-          <VariantHeader variant={variant} />
+          <VariantHeader variant={variant} canHide={variants.data.length > 1} />
         </GridItem>
       ))}
       <GridItem
@@ -94,7 +94,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
       {variants.data.map((variant) => (
         <GridItem key={variant.uiId}>
-          <VariantConfigEditor variant={variant} />
+          <VariantEditor variant={variant} />
         </GridItem>
       ))}
       {variants.data.map((variant) => (
@@ -103,7 +103,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
         </GridItem>
       ))}
       {scenarios.data.map((scenario) => (
-        <ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
+        <ScenarioRow
+          key={scenario.uiId}
+          scenario={scenario}
+          variants={variants.data}
+          canHide={scenarios.data.length > 1}
+        />
       ))}
       <GridItem borderBottomWidth={0} borderRightWidth={0} w="100%" colSpan={allCols} padding={0}>
         <NewScenarioButton />

src/pages/index.tsx

@@ -1,7 +1,7 @@
 import { type GetServerSideProps } from "next";

 // eslint-disable-next-line @typescript-eslint/require-await
-export const getServerSideProps: GetServerSideProps = async (context) => {
+export const getServerSideProps: GetServerSideProps = async () => {
   return {
     redirect: {
       destination: "/experiments",

src/server/api/autogen.ts

@@ -3,10 +3,6 @@ import { prisma } from "../db";
 import { openai } from "../utils/openai";
 import { pick } from "lodash";

-function promptHasVariable(prompt: string, variableName: string) {
-  return prompt.includes(`{{${variableName}}}`);
-}
-
 type AxiosError = {
   response?: {
     data?: {

src/server/api/root.ts

@@ -2,7 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router
 import { createTRPCRouter } from "~/server/api/trpc";
 import { experimentsRouter } from "./routers/experiments.router";
 import { scenariosRouter } from "./routers/scenarios.router";
-import { modelOutputsRouter } from "./routers/modelOutputs.router";
+import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
 import { templateVarsRouter } from "./routers/templateVariables.router";
 import { evaluationsRouter } from "./routers/evaluations.router";
@@ -15,7 +15,7 @@ export const appRouter = createTRPCRouter({
   promptVariants: promptVariantsRouter,
   experiments: experimentsRouter,
   scenarios: scenariosRouter,
-  outputs: modelOutputsRouter,
+  scenarioVariantCells: scenarioVariantCellsRouter,
   templateVars: templateVarsRouter,
   evaluations: evaluationsRouter,
 });

src/server/api/routers/experiments.router.ts

@@ -2,6 +2,7 @@ import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import dedent from "dedent";
+import { generateNewCell } from "~/server/utils/generateNewCell";

 export const experimentsRouter = createTRPCRouter({
   list: publicProcedure.query(async () => {
@@ -64,7 +65,7 @@
       },
     });

-    await prisma.$transaction([
+    const [variant, _, scenario] = await prisma.$transaction([
       prisma.promptVariant.create({
         data: {
           experimentId: exp.id,
@@ -73,18 +74,29 @@
           constructFn: dedent`prompt = {
             model: "gpt-3.5-turbo-0613",
             stream: true,
-            messages: [{ role: "system", content: "Return 'Ready to go!'" }],
+            messages: [{ role: "system", content: ${"`Return '${scenario.text}'`"} }],
           }`,
+          model: "gpt-3.5-turbo-0613",
+        },
+      }),
+      prisma.templateVariable.create({
+        data: {
+          experimentId: exp.id,
+          label: "text",
         },
       }),
       prisma.testScenario.create({
         data: {
           experimentId: exp.id,
-          variableValues: {},
+          variableValues: {
+            text: "This is a test scenario.",
+          },
         },
       }),
     ]);

+    await generateNewCell(variant.id, scenario.id);
+
     return exp;
   }),
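generateNewCell itself isn't included in this compare view, although this router and the scenario/variant routers below all call it. Judging from those call sites and the PENDING → IN_PROGRESS → COMPLETE retrieval flow, it plausibly creates a pending cell and queues retrieval — a sketch under those assumptions, not the repo's actual file:

// Assumed shape of src/server/utils/generateNewCell.ts (not in this diff).
import { prisma } from "~/server/db";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";

export const generateNewCell = async (variantId: string, scenarioId: string) => {
  // Create the cell for this variant/scenario pair in the PENDING state,
  // or leave an existing cell untouched.
  const cell = await prisma.scenarioVariantCell.upsert({
    where: {
      promptVariantId_testScenarioId: {
        promptVariantId: variantId,
        testScenarioId: scenarioId,
      },
    },
    create: {
      promptVariantId: variantId,
      testScenarioId: scenarioId,
      retrievalStatus: "PENDING",
    },
    update: {},
  });

  // Hand the cell to the background worker for LLM retrieval.
  await queueLLMRetrievalTask(cell.id);
};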

src/server/api/routers/modelOutputs.router.ts (deleted)

@@ -1,97 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import crypto from "crypto";
import type { Prisma } from "@prisma/client";
import { reevaluateVariant } from "~/server/utils/evaluations";
import { getCompletion } from "~/server/utils/getCompletion";
import { constructPrompt } from "~/server/utils/constructPrompt";

export const modelOutputsRouter = createTRPCRouter({
  get: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
        channel: z.string().optional(),
        forceRefetch: z.boolean().optional(),
      }),
    )
    .mutation(async ({ input }) => {
      const existing = await prisma.modelOutput.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
      });
      if (existing && !input.forceRefetch) return existing;

      const variant = await prisma.promptVariant.findUnique({
        where: {
          id: input.variantId,
        },
      });
      const scenario = await prisma.testScenario.findUnique({
        where: {
          id: input.scenarioId,
        },
      });
      if (!variant || !scenario) return null;

      const prompt = await constructPrompt(variant, scenario);
      const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");

      // TODO: we should probably only use this if temperature=0
      const existingResponse = await prisma.modelOutput.findFirst({
        where: { inputHash, errorMessage: null },
      });

      let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
      if (existingResponse) {
        modelResponse = {
          output: existingResponse.output as Prisma.InputJsonValue,
          statusCode: existingResponse.statusCode,
          errorMessage: existingResponse.errorMessage,
          timeToComplete: existingResponse.timeToComplete,
          promptTokens: existingResponse.promptTokens ?? undefined,
          completionTokens: existingResponse.completionTokens ?? undefined,
        };
      } else {
        try {
          modelResponse = await getCompletion(prompt, input.channel);
        } catch (e) {
          console.error(e);
          throw e;
        }
      }

      const modelOutput = await prisma.modelOutput.upsert({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        create: {
          promptVariantId: input.variantId,
          testScenarioId: input.scenarioId,
          inputHash,
          ...modelResponse,
        },
        update: {
          ...modelResponse,
        },
      });

      await reevaluateVariant(input.variantId);
      return modelOutput;
    }),
});

src/server/api/routers/promptVariants.router.ts

@@ -1,6 +1,12 @@
+import dedent from "dedent";
+import { isObject } from "lodash";
 import { z } from "zod";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { generateNewCell } from "~/server/utils/generateNewCell";
+import { OpenAIChatModel } from "~/server/types";
+import { constructPrompt } from "~/server/utils/constructPrompt";
+import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
 import { calculateTokenCost } from "~/utils/calculateTokenCost";
@@ -39,17 +45,24 @@
         visible: true,
       },
     });

-    const outputCount = await prisma.modelOutput.count({
+    const outputCount = await prisma.scenarioVariantCell.count({
       where: {
         promptVariantId: input.variantId,
         testScenario: { visible: true },
+        modelOutput: {
+          isNot: null,
+        },
       },
     });

     const overallTokens = await prisma.modelOutput.aggregate({
       where: {
+        scenarioVariantCell: {
           promptVariantId: input.variantId,
-          testScenario: { visible: true },
+          testScenario: {
+            visible: true,
+          },
+        },
       },
       _sum: {
         promptTokens: true,
@@ -57,14 +70,10 @@
       },
     });

-    // TODO: fix this
-    const model = "gpt-3.5-turbo-0613";
-    // const model = getModelName(variant.config);
-
     const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-    const overallPromptCost = calculateTokenCost(model, promptTokens);
+    const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
     const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-    const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
+    const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);

     const overallCost = overallPromptCost + overallCompletionCost;
@@ -105,7 +114,19 @@
         experimentId: input.experimentId,
         label: `Prompt Variant ${largestSortIndex + 2}`,
         sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
-        constructFn: lastVariant?.constructFn ?? "",
+        constructFn:
+          lastVariant?.constructFn ??
+          dedent`
+            prompt = {
+              model: "gpt-3.5-turbo",
+              messages: [
+                {
+                  role: "system",
+                  content: "Return 'Hello, world!'",
+                }
+              ]
+            }`,
+        model: lastVariant?.model ?? "gpt-3.5-turbo",
       },
     });
@@ -114,6 +135,17 @@
       recordExperimentUpdated(input.experimentId),
     ]);

+    const scenarios = await prisma.testScenario.findMany({
+      where: {
+        experimentId: input.experimentId,
+        visible: true,
+      },
+    });
+
+    for (const scenario of scenarios) {
+      await generateNewCell(newVariant.id, scenario.id);
+    }
+
     return newVariant;
   }),
@@ -185,6 +217,27 @@
       throw new Error(`Prompt Variant with id ${input.id} does not exist`);
     }

+    let model = existing.model;
+    try {
+      const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null);
+
+      if (!isObject(contructedPrompt)) {
+        return userError("Prompt is not an object");
+      }
+      if (!("model" in contructedPrompt)) {
+        return userError("Prompt does not define a model");
+      }
+      if (
+        typeof contructedPrompt.model !== "string" ||
+        !(contructedPrompt.model in OpenAIChatModel)
+      ) {
+        return userError("Prompt defines an invalid model");
+      }
+      model = contructedPrompt.model;
+    } catch (e) {
+      return userError((e as Error).message);
+    }
+
     // Create a duplicate with only the config changed
     const newVariant = await prisma.promptVariant.create({
       data: {
@@ -193,11 +246,12 @@
         sortIndex: existing.sortIndex,
         uiId: existing.uiId,
         constructFn: input.constructFn,
+        model,
       },
     });

     // Hide anything with the same uiId besides the new one
-    const hideOldVariantsAction = prisma.promptVariant.updateMany({
+    const hideOldVariants = prisma.promptVariant.updateMany({
       where: {
         uiId: existing.uiId,
         id: {
@@ -209,12 +263,20 @@
       },
     });

-    await prisma.$transaction([
-      hideOldVariantsAction,
-      recordExperimentUpdated(existing.experimentId),
-    ]);
+    await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);

-    return newVariant;
+    const scenarios = await prisma.testScenario.findMany({
+      where: {
+        experimentId: newVariant.experimentId,
+        visible: true,
+      },
+    });
+
+    for (const scenario of scenarios) {
+      await generateNewCell(newVariant.id, scenario.id);
+    }
+
+    return { status: "ok" } as const;
   }),

   reorder: publicProcedure

src/server/api/routers/scenarioVariantCells.router.ts (new file)

@@ -0,0 +1,68 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";

export const scenarioVariantCellsRouter = createTRPCRouter({
  get: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .query(async ({ input }) => {
      return await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: true,
        },
      });
    }),

  forceRefetch: publicProcedure
    .input(
      z.object({
        scenarioId: z.string(),
        variantId: z.string(),
      }),
    )
    .mutation(async ({ input }) => {
      const cell = await prisma.scenarioVariantCell.findUnique({
        where: {
          promptVariantId_testScenarioId: {
            promptVariantId: input.variantId,
            testScenarioId: input.scenarioId,
          },
        },
        include: {
          modelOutput: true,
        },
      });

      if (!cell) {
        await generateNewCell(input.variantId, input.scenarioId);
        return true;
      }

      if (cell.modelOutput) {
        // TODO: Maybe keep these around to show previous generations?
        await prisma.modelOutput.delete({
          where: { id: cell.modelOutput.id },
        });
      }

      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: { retrievalStatus: "PENDING" },
      });

      await queueLLMRetrievalTask(cell.id);
      return true;
    }),
});
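queueLLMRetrievalTask (imported above) is also missing from this compare view. Given the defineTask/queryLLM plumbing shown further down, it is presumably a thin wrapper around the task's enqueue — a sketch, not the actual file:

// Assumed shape of src/server/utils/queueLLMRetrievalTask.ts (not in this diff).
import { queryLLM } from "../tasks/queryLLM.task";

export const queueLLMRetrievalTask = async (scenarioVariantCellId: string) => {
  // Enqueue a graphile-worker job; the worker process (worker.ts below)
  // picks it up and runs the queryLLM task.
  await queryLLM.enqueue({ scenarioVariantCellId });
};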

src/server/api/routers/scenarios.router.ts

@@ -4,6 +4,7 @@ import { prisma } from "~/server/db";
 import { autogenerateScenarioValues } from "../autogen";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
 import { reevaluateAll } from "~/server/utils/evaluations";
+import { generateNewCell } from "~/server/utils/generateNewCell";

 export const scenariosRouter = createTRPCRouter({
   list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -48,10 +49,21 @@
       },
     });

-    await prisma.$transaction([
+    const [scenario] = await prisma.$transaction([
       createNewScenarioAction,
       recordExperimentUpdated(input.experimentId),
     ]);
+
+    const promptVariants = await prisma.promptVariant.findMany({
+      where: {
+        experimentId: input.experimentId,
+        visible: true,
+      },
+    });
+
+    for (const variant of promptVariants) {
+      await generateNewCell(variant.id, scenario.id);
+    }
   }),

   hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
@@ -175,6 +187,17 @@
       },
     });

+    const promptVariants = await prisma.promptVariant.findMany({
+      where: {
+        experimentId: newScenario.experimentId,
+        visible: true,
+      },
+    });
+
+    for (const variant of promptVariants) {
+      await generateNewCell(variant.id, newScenario.id);
+    }
+
     return newScenario;
   }),
 });

(new file: one-off data migration script)

@@ -0,0 +1,47 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";

async function migrateScenarioVariantOutputData() {
  // Get all ScenarioVariantCells
  const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
  console.log(`Found ${cells.length} records`);
  let updatedCount = 0;
  // Loop through all scenarioVariants
  for (const cell of cells) {
    // Create a new ModelOutput for each ScenarioVariant with an existing output
    if (cell.output && !cell.modelOutput) {
      updatedCount++;
      await prisma.modelOutput.create({
        data: {
          scenarioVariantCellId: cell.id,
          inputHash: cell.inputHash || "",
          output: cell.output as Prisma.InputJsonValue,
          timeToComplete: cell.timeToComplete ?? undefined,
          promptTokens: cell.promptTokens,
          completionTokens: cell.completionTokens,
          createdAt: cell.createdAt,
          updatedAt: cell.updatedAt,
        },
      });
    } else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
      updatedCount++;
      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: {
          retrievalStatus: "ERROR",
        },
      });
    }
  }
  console.log("Data migration completed");
  console.log(`Updated ${updatedCount} records`);
}

// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
  console.error("An error occurred while migrating data: ", error);
  process.exit(1);
});

src/server/tasks/defineTask.ts (new file)

@@ -0,0 +1,31 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs";

// Define the defineTask function
function defineTask<TPayload>(
  taskIdentifier: string,
  taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
  const enqueue = async (payload: TPayload) => {
    console.log("Enqueuing task", taskIdentifier, payload);
    await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
  };

  const handler = (payload: TPayload, helpers: Helpers) => {
    helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
    return taskHandler(payload, helpers);
  };

  const task = {
    identifier: taskIdentifier,
    handler: handler as Task,
  };

  return {
    enqueue,
    task,
  };
}

export default defineTask;

src/server/tasks/queryLLM.task.ts (new file)

@@ -0,0 +1,144 @@
import crypto from "crypto";
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat";

const MAX_AUTO_RETRIES = 10;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}
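// Worked example of the schedule above (an added note, not part of the original file):
// jitter is uniform in [0, baseDelay), so retry 0 waits 0.5-1s, retry 1 waits 1-2s,
// retry 2 waits 2-4s; from retry 5 on the base is capped at MAX_DELAY, so each
// wait falls between 15 and 30 seconds.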
const getCompletionWithRetries = async (
cellId: string,
payload: JSONSerializable,
channel?: string,
): Promise<CompletionResponse> => {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
const modelResponse = await getCompletion(
payload as unknown as CompletionCreateParams,
channel,
);
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
return modelResponse;
}
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
errorMessage: "Rate limit exceeded",
statusCode: 429,
retryTime: new Date(Date.now() + delay),
},
});
// TODO: Maybe requeue the job so other jobs can run in the future?
await sleep(delay);
}
throw new Error("Max retries limit reached");
};
export type queryLLMJob = {
scenarioVariantCellId: string;
};
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const { scenarioVariantCellId } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: scenarioVariantCellId },
include: { modelOutput: true },
});
if (!cell) {
return;
}
// If cell is not pending, then some other job is already processing it
if (cell.retrievalStatus !== "PENDING") {
return;
}
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
retrievalStatus: "IN_PROGRESS",
},
});
const variant = await prisma.promptVariant.findUnique({
where: { id: cell.promptVariantId },
});
if (!variant) {
return;
}
const scenario = await prisma.testScenario.findUnique({
where: { id: cell.testScenarioId },
});
if (!scenario) {
return;
}
const prompt = await constructPrompt(variant, scenario.variableValues);
const streamingEnabled = shouldStream(prompt);
let streamingChannel;
if (streamingEnabled) {
streamingChannel = generateChannel();
// Save streaming channel so that UI can connect to it
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
streamingChannel,
},
});
}
const modelResponse = await getCompletionWithRetries(
scenarioVariantCellId,
prompt,
streamingChannel,
);
let modelOutput = null;
if (modelResponse.statusCode === 200) {
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
modelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId,
inputHash,
output: modelResponse.output,
timeToComplete: modelResponse.timeToComplete,
promptTokens: modelResponse.promptTokens,
completionTokens: modelResponse.completionTokens,
},
});
}
await prisma.scenarioVariantCell.update({
  where: { id: scenarioVariantCellId },
  data: {
    statusCode: modelResponse.statusCode,
    errorMessage: modelResponse.errorMessage,
    streamingChannel: null,
    retrievalStatus: modelOutput ? "COMPLETE" : "ERROR",
    // Only connect the relation when an output was created; calling `connect`
    // with an undefined id would make Prisma throw on the error path.
    ...(modelOutput ? { modelOutput: { connect: { id: modelOutput.id } } } : {}),
  },
});
await reevaluateVariant(cell.promptVariantId);
});
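To summarize the retrievalStatus lifecycle the task above implements:

// PENDING     -> set when the cell is queued for retrieval
// IN_PROGRESS -> claimed by this job before the model is called
// COMPLETE    -> a 200 response produced a ModelOutput row
// ERROR       -> a non-200 response left the cell without output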


@@ -0,0 +1,40 @@
import { type TaskList, run } from "graphile-worker";
import "dotenv/config";
import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task";
const registeredTasks = [queryLLM];
const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler;
return acc;
}, {} as TaskList);
async function main() {
// Run a worker to execute jobs:
const runner = await run({
connectionString: env.DATABASE_URL,
concurrency: 20,
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
noHandleSignals: false,
pollInterval: 1000,
// you can set the taskList or taskDirectory but not both
taskList,
// or:
// taskDirectory: `${__dirname}/tasks`,
});
  // Immediately await (or otherwise handle) the resulting promise, to avoid
  // "unhandled rejection" errors causing a process crash in the event of
  // something going wrong.
await runner.promise;
// If the worker exits (whether through fatal error or otherwise), the above
// promise will resolve/reject.
}
main().catch((err) => {
console.error("Unhandled error occurred running worker: ", err);
process.exit(1);
});
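For reference, other processes can add jobs to this worker's queue with graphile-worker's quickAddJob helper (a sketch, not code from this changeset; the cell id is a placeholder):

import { quickAddJob } from "graphile-worker";

await quickAddJob(
  { connectionString: env.DATABASE_URL },
  "queryLLM", // must match the identifier registered in the task list above
  { scenarioVariantCellId: "<cell-id>" },
);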


@@ -7,10 +7,8 @@ test.skip("constructPrompt", async () => {
     constructFn: `prompt = { "fooz": "bar" }`,
   },
   {
-    variableValues: {
     foo: "bar",
   },
-  },
 );
 console.log(constructed);


@@ -6,12 +6,10 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
 export async function constructPrompt(
   variant: Pick<PromptVariant, "constructFn">,
-  testScenario: Pick<TestScenario, "variableValues">,
+  scenario: TestScenario["variableValues"],
 ): Promise<JSONSerializable> {
-  const scenario = testScenario.variableValues as JSONSerializable;
   const code = `
-  const scenario = ${JSON.stringify(scenario, null, 2)};
+  const scenario = ${JSON.stringify(scenario ?? {}, null, 2)};
   let prompt
   ${variant.constructFn}
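Note the "?? {}" fallback above: a scenario whose variableValues are null now yields an empty object inside the sandbox rather than a literal null.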


@@ -0,0 +1,6 @@
export default function userError(message: string): { status: "error"; message: string } {
return {
status: "error",
message,
};
}
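Usage sketch (the call site is hypothetical): a mutation that can't find its record might return userError("Scenario not found") rather than throwing, so the client receives a typed error status.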


@@ -1,4 +1,4 @@
-import { type Evaluation } from "@prisma/client";
+import { type ModelOutput, type Evaluation } from "@prisma/client";
 import { prisma } from "../db";
 import { evaluateOutput } from "./evaluateOutput";
@@ -12,21 +12,22 @@ export const reevaluateVariant = async (variantId: string) => {
     where: { experimentId: variant.experimentId },
   });
-  const modelOutputs = await prisma.modelOutput.findMany({
+  const cells = await prisma.scenarioVariantCell.findMany({
     where: {
       promptVariantId: variantId,
-      statusCode: { notIn: [429] },
+      retrievalStatus: "COMPLETE",
       testScenario: { visible: true },
+      modelOutput: { isNot: null },
     },
-    include: { testScenario: true },
+    include: { testScenario: true, modelOutput: true },
   });
   await Promise.all(
     evaluations.map(async (evaluation) => {
-      const passCount = modelOutputs.filter((output) =>
-        evaluateOutput(output, output.testScenario, evaluation),
+      const passCount = cells.filter((cell) =>
+        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
       ).length;
-      const failCount = modelOutputs.length - passCount;
+      const failCount = cells.length - passCount;
       await prisma.evaluationResult.upsert({
         where: {
@@ -55,22 +56,23 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
     where: { experimentId: evaluation.experimentId, visible: true },
   });
-  const modelOutputs = await prisma.modelOutput.findMany({
+  const cells = await prisma.scenarioVariantCell.findMany({
     where: {
       promptVariantId: { in: variants.map((v) => v.id) },
       testScenario: { visible: true },
       statusCode: { notIn: [429] },
+      modelOutput: { isNot: null },
     },
-    include: { testScenario: true },
+    include: { testScenario: true, modelOutput: true },
   });
   await Promise.all(
     variants.map(async (variant) => {
-      const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
-      const passCount = outputs.filter((output) =>
-        evaluateOutput(output, output.testScenario, evaluation),
+      const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
+      const passCount = variantCells.filter((cell) =>
+        evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
       ).length;
-      const failCount = outputs.length - passCount;
+      const failCount = variantCells.length - passCount;
       await prisma.evaluationResult.upsert({
         where: {


@@ -0,0 +1,76 @@
import crypto from "crypto";
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import { constructPrompt } from "./constructPrompt";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
},
});
const scenario = await prisma.testScenario.findUnique({
where: {
id: scenarioId,
},
});
if (!variant || !scenario) return null;
const prompt = await constructPrompt(variant, scenario.variableValues);
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
},
include: {
modelOutput: true,
},
});
if (cell) return cell;
cell = await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
},
include: {
modelOutput: true,
},
});
const matchingModelOutput = await prisma.modelOutput.findFirst({
where: {
inputHash,
},
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};
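Design note: because the cache key is a SHA-256 hash of the fully constructed prompt, any two cells whose variant and scenario produce byte-identical prompts share a single model call; only cache misses fall through to queueLLMRetrievalTask.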


@@ -4,15 +4,12 @@ import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModel } from "../types";
+import { type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
-import { getModelName } from "./getModelName";
 import { rateLimitErrorMessage } from "~/sharedStrings";
-env;
-type CompletionResponse = {
+export type CompletionResponse = {
   output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
   statusCode: number;
   errorMessage: string | null;
@@ -22,35 +19,7 @@ type CompletionResponse = {
 };
 export async function getCompletion(
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const modelName = getModelName(payload);
-  if (!modelName)
-    return {
-      output: Prisma.JsonNull,
-      statusCode: 400,
-      errorMessage: "Invalid payload provided",
-      timeToComplete: 0,
-    };
-  if (modelName in OpenAIChatModel) {
-    return getOpenAIChatCompletion(
-      payload as unknown as CompletionCreateParams,
-      env.OPENAI_API_KEY,
-      channel,
-    );
-  }
-  return {
-    output: Prisma.JsonNull,
-    statusCode: 400,
-    errorMessage: "Invalid model provided",
-    timeToComplete: 0,
-  };
-}
-export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
-  apiKey: string,
   channel?: string,
 ): Promise<CompletionResponse> {
   // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
   method: "POST",
   headers: {
     "Content-Type": "application/json",
-    Authorization: `Bearer ${apiKey}`,
+    Authorization: `Bearer ${env.OPENAI_API_KEY}`,
   },
   body: JSON.stringify(payload),
 });


@@ -1,9 +0,0 @@
import { isObject } from "lodash";
import { type JSONSerializable, type SupportedModel } from "../types";
import { type Prisma } from "@prisma/client";
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
if (!isObject(config)) return null;
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
return null;
}


@@ -0,0 +1,22 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";
export const queueLLMRetrievalTask = async (cellId: string) => {
const updatedCell = await prisma.scenarioVariantCell.update({
where: {
id: cellId,
},
data: {
retrievalStatus: "PENDING",
errorMessage: null,
},
include: {
modelOutput: true,
},
});
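// Note: this runs the handler inline, in-process, with a stand-in console logger
// rather than enqueueing the job through graphile-worker.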
// @ts-expect-error we aren't passing the helpers but that's ok
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
return updatedCell;
};


@@ -0,0 +1,7 @@
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean =>
  isObject(config) && "stream" in config && config.stream === true;


@@ -0,0 +1 @@
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));


@@ -2,12 +2,9 @@ import { type RouterOutputs } from "~/utils/api";
 import { type SliceCreator } from "./store";
 import loader from "@monaco-editor/loader";
 import openAITypes from "~/codegen/openai.types.ts.txt";
-import prettier from "prettier/standalone";
-import parserTypescript from "prettier/plugins/typescript";
-// @ts-expect-error for some reason missing from types
-import parserEstree from "prettier/plugins/estree";
-import { type languages } from "monaco-editor/esm/vs/editor/editor.api";
+import formatPromptConstructor from "~/utils/formatPromptConstructor";
+export const editorBackground = "#fafafa";
 export type SharedVariantEditorSlice = {
   monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
@@ -17,29 +14,12 @@ export type SharedVariantEditorSlice = {
   setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]) => void;
 };
-const customFormatter: languages.DocumentFormattingEditProvider = {
-  provideDocumentFormattingEdits: async (model) => {
-    const val = model.getValue();
-    console.log("going to format!", val);
-    const text = await prettier.format(val, {
-      parser: "typescript",
-      plugins: [parserTypescript, parserEstree],
-      // We're showing these in pretty narrow panes so let's keep the print width low
-      printWidth: 60,
-    });
-    return [
-      {
-        range: model.getFullModelRange(),
-        text,
-      },
-    ];
-  },
-};
 export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
   monaco: loader.__getMonacoInstance(),
   loadMonaco: async () => {
-    // We only want to run this client-side
-    if (typeof window === "undefined") return;
     const monaco = await loader.init();
     monaco.editor.defineTheme("customTheme", {
@@ -47,12 +27,13 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
       inherit: true,
       rules: [],
       colors: {
-        "editor.background": "#fafafa",
+        "editor.background": editorBackground,
       },
     });
     monaco.languages.typescript.typescriptDefaults.setCompilerOptions({
       allowNonTsExtensions: true,
+      strictNullChecks: true,
       lib: ["esnext"],
     });
@@ -66,7 +47,16 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
       monaco.Uri.parse("file:///openai.types.ts"),
     );
-    monaco.languages.registerDocumentFormattingEditProvider("typescript", customFormatter);
+    monaco.languages.registerDocumentFormattingEditProvider("typescript", {
+      provideDocumentFormattingEdits: async (model) => {
+        return [
+          {
+            range: model.getFullModelRange(),
+            text: await formatPromptConstructor(model.getValue()),
+          },
+        ];
+      },
+    });
     set((state) => {
       state.sharedVariantEditor.monaco = monaco;
@@ -95,7 +85,7 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
     )} as const;
     type Scenario = typeof scenarios[number];
-    declare var scenario: Scenario | null;
+    declare var scenario: Scenario | { [key: string]: string };
     `;
     const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));


@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
 };
 export const calculateTokenCost = (
-  model: SupportedModel | null,
+  model: SupportedModel | string | null,
   numTokens: number,
   isCompletion = false,
 ) => {


@@ -0,0 +1,10 @@
import { expect, test } from "vitest";
import { stripTypes } from "./formatPromptConstructor";
test("stripTypes", () => {
expect(stripTypes(`const foo: string = "bar";`)).toBe(`const foo = "bar";`);
});
test("stripTypes with invalid syntax", () => {
expect(stripTypes(`asdf foo: string = "bar"`)).toBe(`asdf foo: string = "bar"`);
});


@@ -0,0 +1,31 @@
import prettier from "prettier/standalone";
import parserTypescript from "prettier/plugins/typescript";
// @ts-expect-error for some reason missing from types
import parserEstree from "prettier/plugins/estree";
import * as babel from "@babel/standalone";
export function stripTypes(tsCode: string): string {
const options = {
presets: ["typescript"],
filename: "file.ts",
};
try {
const result = babel.transform(tsCode, options);
return result.code ?? tsCode;
} catch (error) {
console.error("Error stripping types", error);
return tsCode;
}
}
export default async function formatPromptConstructor(code: string): Promise<string> {
return await prettier.format(stripTypes(code), {
parser: "typescript",
plugins: [parserTypescript, parserEstree],
// We're showing these in pretty narrow panes so let's keep the print width low
printWidth: 60,
});
}
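A quick usage sketch (the input string is invented for illustration):

const formatted = await formatPromptConstructor(
  'const temperature: number = 0;\nprompt={model:"gpt-4",temperature}',
);
// formatted is roughly: const temperature = 0; prompt = { model: "gpt-4", temperature };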


@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
 const url = env.NEXT_PUBLIC_SOCKET_URL;
-export default function useSocket(channel?: string) {
+export default function useSocket(channel?: string | null) {
   const socketRef = useRef<Socket>();
   const [message, setMessage] = useState<ChatCompletion | null>(null);