diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index 3bad8cc..baa9f4f 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -26,10 +26,11 @@ model Experiment {
 }
 
 model PromptVariant {
-  id          String  @id @default(uuid()) @db.Uuid
-  label       String
+  id          String  @id @default(uuid()) @db.Uuid
+  label       String
   constructFn String
+  model       String  @default("gpt-3.5-turbo")
 
   uiId String @default(uuid()) @db.Uuid
 
   visible Boolean @default(true)
diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
index 6f0f615..8307852 100644
--- a/src/components/OutputsTable/OutputCell/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -37,10 +37,6 @@ export default function OutputCell({
   //   if (variant.config === null || Object.keys(variant.config).length === 0)
   //     disabledReason = "Save your prompt variant to see output";
 
-  // const model = getModelName(variant.config as JSONSerializable);
-  // TODO: Temporarily hardcoding this while we get other stuff working
-  const model = "gpt-3.5-turbo";
-
   const outputMutation = api.outputs.get.useMutation();
 
   const [output, setOutput] = useState<ModelOutput | null>(null);
@@ -140,7 +136,7 @@ export default function OutputCell({
         { maxLength: 40 },
       )}
-      <OutputStats model={model} modelOutput={output} scenario={scenario} />
+      <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
   );
 }
@@ -150,7 +146,7 @@ export default function OutputCell({
   return (
     {contentToDisplay}
-      {output && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
+      {output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
   );
 }
diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx
index 4491300..02e5780 100644
--- a/src/components/OutputsTable/OutputCell/OutputStats.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx
@@ -17,7 +17,7 @@ export const OutputStats = ({
   modelOutput,
   scenario,
 }: {
-  model: SupportedModel | null;
+  model: SupportedModel | string | null;
   modelOutput: ModelOutput;
   scenario: Scenario;
 }) => {
diff --git a/src/components/OutputsTable/VariantEditor.tsx b/src/components/OutputsTable/VariantEditor.tsx
index d1eb1ce..526bed0 100644
--- a/src/components/OutputsTable/VariantEditor.tsx
+++ b/src/components/OutputsTable/VariantEditor.tsx
@@ -59,10 +59,17 @@ export default function VariantConfigEditor(props: { variant: PromptVariant }) {
       return;
     }
 
-    await replaceVariant.mutateAsync({
+    const resp = await replaceVariant.mutateAsync({
      id: props.variant.id,
      constructFn: currentFn,
    });
+    if (resp.status === "error") {
+      return toast({
+        title: "Error saving variant",
+        description: resp.message,
+        status: "error",
+      });
+    }
 
     await utils.promptVariants.list.invalidate();
diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts
index 21e6b52..1c47d7a 100644
--- a/src/server/api/routers/experiments.router.ts
+++ b/src/server/api/routers/experiments.router.ts
@@ -75,6 +75,7 @@ export const experimentsRouter = createTRPCRouter({
           stream: true,
           messages: [{ role: "system", content: "Return 'Ready to go!'" }],
         }`,
+          model: "gpt-3.5-turbo-0613",
         },
       }),
       prisma.testScenario.create({
diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts
index 848305d..7b11850 100644
--- a/src/server/api/routers/modelOutputs.router.ts
+++ b/src/server/api/routers/modelOutputs.router.ts
@@ -6,6 +6,7 @@ import type { Prisma } from "@prisma/client";
 import { reevaluateVariant } from "~/server/utils/evaluations";
 import { getCompletion } from "~/server/utils/getCompletion";
 import { constructPrompt } from "~/server/utils/constructPrompt";
"openai/resources/chat"; export const modelOutputsRouter = createTRPCRouter({ get: publicProcedure @@ -43,7 +44,7 @@ export const modelOutputsRouter = createTRPCRouter({ if (!variant || !scenario) return null; - const prompt = await constructPrompt(variant, scenario); + const prompt = await constructPrompt(variant, scenario.variableValues); const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex"); @@ -65,7 +66,10 @@ export const modelOutputsRouter = createTRPCRouter({ }; } else { try { - modelResponse = await getCompletion(prompt, input.channel); + modelResponse = await getCompletion( + prompt as unknown as CompletionCreateParams, + input.channel, + ); } catch (e) { console.error(e); throw e; diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index 6e10ea6..943c4c1 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -1,6 +1,10 @@ +import { isObject } from "lodash"; import { z } from "zod"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; +import { OpenAIChatModel } from "~/server/types"; +import { constructPrompt } from "~/server/utils/constructPrompt"; +import userError from "~/server/utils/error"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { calculateTokenCost } from "~/utils/calculateTokenCost"; @@ -57,14 +61,10 @@ export const promptVariantsRouter = createTRPCRouter({ }, }); - // TODO: fix this - const model = "gpt-3.5-turbo-0613"; - // const model = getModelName(variant.config); - const promptTokens = overallTokens._sum?.promptTokens ?? 0; - const overallPromptCost = calculateTokenCost(model, promptTokens); + const overallPromptCost = calculateTokenCost(variant.model, promptTokens); const completionTokens = overallTokens._sum?.completionTokens ?? 0; - const overallCompletionCost = calculateTokenCost(model, completionTokens, true); + const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true); const overallCost = overallPromptCost + overallCompletionCost; @@ -106,6 +106,7 @@ export const promptVariantsRouter = createTRPCRouter({ label: `Prompt Variant ${largestSortIndex + 2}`, sortIndex: (lastVariant?.sortIndex ?? 0) + 1, constructFn: lastVariant?.constructFn ?? "", + model: lastVariant?.model ?? 
"gpt-3.5-turbo", }, }); @@ -185,6 +186,27 @@ export const promptVariantsRouter = createTRPCRouter({ throw new Error(`Prompt Variant with id ${input.id} does not exist`); } + let model = existing.model; + try { + const contructedPrompt = await constructPrompt({ constructFn: input.constructFn }, null); + + if (!isObject(contructedPrompt)) { + return userError("Prompt is not an object"); + } + if (!("model" in contructedPrompt)) { + return userError("Prompt does not define a model"); + } + if ( + typeof contructedPrompt.model !== "string" || + !(contructedPrompt.model in OpenAIChatModel) + ) { + return userError("Prompt defines an invalid model"); + } + model = contructedPrompt.model; + } catch (e) { + return userError((e as Error).message); + } + // Create a duplicate with only the config changed const newVariant = await prisma.promptVariant.create({ data: { @@ -193,11 +215,12 @@ export const promptVariantsRouter = createTRPCRouter({ sortIndex: existing.sortIndex, uiId: existing.uiId, constructFn: input.constructFn, + model, }, }); // Hide anything with the same uiId besides the new one - const hideOldVariantsAction = prisma.promptVariant.updateMany({ + const hideOldVariants = prisma.promptVariant.updateMany({ where: { uiId: existing.uiId, id: { @@ -209,12 +232,9 @@ export const promptVariantsRouter = createTRPCRouter({ }, }); - await prisma.$transaction([ - hideOldVariantsAction, - recordExperimentUpdated(existing.experimentId), - ]); + await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]); - return newVariant; + return { status: "ok" } as const; }), reorder: publicProcedure diff --git a/src/server/utils/constructPrompt.test.ts b/src/server/utils/constructPrompt.test.ts index 78c485f..d0c98ff 100644 --- a/src/server/utils/constructPrompt.test.ts +++ b/src/server/utils/constructPrompt.test.ts @@ -7,9 +7,7 @@ test.skip("constructPrompt", async () => { constructFn: `prompt = { "fooz": "bar" }`, }, { - variableValues: { - foo: "bar", - }, + foo: "bar", }, ); diff --git a/src/server/utils/constructPrompt.ts b/src/server/utils/constructPrompt.ts index 30eb5e6..3b7c283 100644 --- a/src/server/utils/constructPrompt.ts +++ b/src/server/utils/constructPrompt.ts @@ -6,10 +6,8 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 }); export async function constructPrompt( variant: Pick, - testScenario: Pick, + scenario: TestScenario["variableValues"], ): Promise { - const scenario = testScenario.variableValues as JSONSerializable; - const code = ` const scenario = ${JSON.stringify(scenario, null, 2)}; let prompt diff --git a/src/server/utils/error.ts b/src/server/utils/error.ts new file mode 100644 index 0000000..f284a8d --- /dev/null +++ b/src/server/utils/error.ts @@ -0,0 +1,6 @@ +export default function userError(message: string): { status: "error"; message: string } { + return { + status: "error", + message, + }; +} diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts index e7b33a2..61452e4 100644 --- a/src/server/utils/getCompletion.ts +++ b/src/server/utils/getCompletion.ts @@ -4,14 +4,11 @@ import { Prisma } from "@prisma/client"; import { streamChatCompletion } from "./openai"; import { wsConnection } from "~/utils/wsConnection"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; -import { type JSONSerializable, OpenAIChatModel } from "../types"; +import { type OpenAIChatModel } from "../types"; import { env } from "~/env.mjs"; import { countOpenAIChatTokens } from "~/utils/countTokens"; 
-import { getModelName } from "./getModelName";
 import { rateLimitErrorMessage } from "~/sharedStrings";
 
-env;
-
 type CompletionResponse = {
   output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
   statusCode: number;
@@ -22,35 +19,7 @@ type CompletionResponse = {
 };
 
 export async function getCompletion(
-  payload: JSONSerializable,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const modelName = getModelName(payload);
-  if (!modelName)
-    return {
-      output: Prisma.JsonNull,
-      statusCode: 400,
-      errorMessage: "Invalid payload provided",
-      timeToComplete: 0,
-    };
-  if (modelName in OpenAIChatModel) {
-    return getOpenAIChatCompletion(
-      payload as unknown as CompletionCreateParams,
-      env.OPENAI_API_KEY,
-      channel,
-    );
-  }
-  return {
-    output: Prisma.JsonNull,
-    statusCode: 400,
-    errorMessage: "Invalid model provided",
-    timeToComplete: 0,
-  };
-}
-
-export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
-  apiKey: string,
   channel?: string,
 ): Promise<CompletionResponse> {
   // If functions are enabled, disable streaming so that we get the full response with token counts
@@ -60,7 +29,7 @@ export async function getOpenAIChatCompletion(
     method: "POST",
     headers: {
       "Content-Type": "application/json",
-      Authorization: `Bearer ${apiKey}`,
+      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
     },
     body: JSON.stringify(payload),
   });
diff --git a/src/server/utils/getModelName.ts b/src/server/utils/getModelName.ts
deleted file mode 100644
index 4fcaef1..0000000
--- a/src/server/utils/getModelName.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-import { isObject } from "lodash";
-import { type JSONSerializable, type SupportedModel } from "../types";
-import { type Prisma } from "@prisma/client";
-
-export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
-  if (!isObject(config)) return null;
-  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
-  return null;
-}
diff --git a/src/state/sharedVariantEditor.slice.ts b/src/state/sharedVariantEditor.slice.ts
index ce455db..d6d8bd8 100644
--- a/src/state/sharedVariantEditor.slice.ts
+++ b/src/state/sharedVariantEditor.slice.ts
@@ -14,6 +14,9 @@ export type SharedVariantEditorSlice = {
 export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
   monaco: loader.__getMonacoInstance(),
   loadMonaco: async () => {
+    // We only want to run this client-side
+    if (typeof window === "undefined") return;
+
     const monaco = await loader.init();
 
     monaco.editor.defineTheme("customTheme", {
diff --git a/src/utils/calculateTokenCost.ts b/src/utils/calculateTokenCost.ts
index a64a289..56c6f0a 100644
--- a/src/utils/calculateTokenCost.ts
+++ b/src/utils/calculateTokenCost.ts
@@ -23,7 +23,7 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
 };
 
 export const calculateTokenCost = (
-  model: SupportedModel | null,
+  model: SupportedModel | string | null,
   numTokens: number,
   isCompletion = false,
 ) => {
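For review convenience, a standalone sketch of the model validation that replaceVariant now performs inline before duplicating a variant. The trimmed-down OpenAIChatModel members and the validateConstructedPrompt helper are illustrative assumptions, not part of the patch; the real enum lives in ~/server/types and the router inlines these checks.

import { isObject } from "lodash";

// Illustrative subset; the real OpenAIChatModel enum lives in ~/server/types.
enum OpenAIChatModel {
  "gpt-3.5-turbo" = "gpt-3.5-turbo",
  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
}

type Validated = { status: "ok"; model: string } | { status: "error"; message: string };

// Mirrors the inline checks above: the constructed prompt must be an object
// whose `model` key names a member of OpenAIChatModel.
function validateConstructedPrompt(constructedPrompt: unknown): Validated {
  if (!isObject(constructedPrompt)) {
    return { status: "error", message: "Prompt is not an object" };
  }
  if (!("model" in constructedPrompt)) {
    return { status: "error", message: "Prompt does not define a model" };
  }
  const { model } = constructedPrompt as { model: unknown };
  if (typeof model !== "string" || !(model in OpenAIChatModel)) {
    return { status: "error", message: "Prompt defines an invalid model" };
  }
  return { status: "ok", model };
}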