From a31c1127458bcb3554f8b54152acc7b1ce5c7973 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Thu, 22 Jun 2023 17:57:21 -0700 Subject: [PATCH] we're actually calling openai --- package.json | 1 + pnpm-lock.yaml | 33 +++++++++++ prisma/schema.prisma | 5 +- prisma/seed.ts | 12 ++-- src/components/OutputsTable/OutputCell.tsx | 19 ++++++ src/components/OutputsTable/VariantHeader.tsx | 59 +++++++++++++++++++ src/components/OutputsTable/index.tsx | 51 ++++++++++++---- src/components/OutputsTable/types.ts | 5 ++ src/components/nav/AppNav.tsx | 8 +-- src/pages/experiments/[id].tsx | 7 +-- src/server/api/root.router.ts | 2 + src/server/api/routers/modelOutputs.router.ts | 53 +++++++++++++++++ src/server/utils/fillTemplate.ts | 27 +++++++++ src/server/utils/openai.ts | 19 ++++++ 14 files changed, 271 insertions(+), 30 deletions(-) create mode 100644 src/components/OutputsTable/OutputCell.tsx create mode 100644 src/components/OutputsTable/VariantHeader.tsx create mode 100644 src/components/OutputsTable/types.ts create mode 100644 src/server/api/routers/modelOutputs.router.ts create mode 100644 src/server/utils/fillTemplate.ts create mode 100644 src/server/utils/openai.ts diff --git a/package.json b/package.json index cc7c9b6..93f669e 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ "@mantine/form": "^6.0.14", "@mantine/hooks": "^6.0.14", "@mantine/next": "^6.0.14", + "@monaco-editor/react": "^4.5.1", "@next-auth/prisma-adapter": "^1.0.5", "@prisma/client": "^4.14.0", "@t3-oss/env-nextjs": "^0.3.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d0a816a..c16098c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -26,6 +26,9 @@ dependencies: '@mantine/next': specifier: ^6.0.14 version: 6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(next@13.4.2)(react-dom@18.2.0)(react@18.2.0) + '@monaco-editor/react': + specifier: ^4.5.1 + version: 4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0) '@next-auth/prisma-adapter': specifier: ^1.0.5 version: 1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1) @@ -688,6 +691,28 @@ packages: react: 18.2.0 dev: false + /@monaco-editor/loader@1.3.3(monaco-editor@0.39.0): + resolution: {integrity: sha512-6KKF4CTzcJiS8BJwtxtfyYt9shBiEv32ateQ9T4UVogwn4HM/uPo9iJd2Dmbkpz8CM6Y0PDUpjnZzCwC+eYo2Q==} + peerDependencies: + monaco-editor: '>= 0.21.0 < 1' + dependencies: + monaco-editor: 0.39.0 + state-local: 1.0.7 + dev: false + + /@monaco-editor/react@4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-NNDFdP+2HojtNhCkRfE6/D6ro6pBNihaOzMbGK84lNWzRu+CfBjwzGt4jmnqimLuqp5yE5viHS2vi+QOAnD5FQ==} + peerDependencies: + monaco-editor: '>= 0.25.0 < 1' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 + dependencies: + '@monaco-editor/loader': 1.3.3(monaco-editor@0.39.0) + monaco-editor: 0.39.0 + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /@next-auth/prisma-adapter@1.0.5(@prisma/client@4.14.0)(next-auth@4.22.1): resolution: {integrity: sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==} peerDependencies: @@ -2882,6 +2907,10 @@ packages: /minimist@1.2.8: resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==} + /monaco-editor@0.39.0: + resolution: {integrity: sha512-zhbZ2Nx93tLR8aJmL2zI1mhJpsl87HMebNBM6R8z4pLfs8pj604pIVIVwyF1TivcfNtIPpMXL+nb3DsBmE/x6Q==} + dev: false + /ms@2.1.2: resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} dev: true @@ -3516,6 +3545,10 @@ packages: engines: {node: '>=0.10.0'} dev: false + /state-local@1.0.7: + resolution: {integrity: sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==} + dev: false + /streamsearch@1.1.0: resolution: {integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==} engines: {node: '>=10.0.0'} diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 9b051c2..2405709 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -70,15 +70,14 @@ model TemplateVariable { model ModelOutput { id String @id @default(uuid()) @db.Uuid + output Json + promptVariantId String @db.Uuid promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id]) testScenarioId String @db.Uuid testScenario TestScenario @relation(fields: [testScenarioId], references: [id]) - variableValues Json - inputsHash String @unique - createdAt DateTime @default(now()) updatedAt DateTime @updatedAt diff --git a/prisma/seed.ts b/prisma/seed.ts index e0a5965..3b8accf 100644 --- a/prisma/seed.ts +++ b/prisma/seed.ts @@ -22,21 +22,21 @@ const resp = await prisma.promptVariant.createMany({ data: [ { experimentId, - label: "Variant 1", + label: "Prompt Variant 1", sortIndex: 0, config: { model: "gpt-3.5-turbo", - messages: [{ role: "user", content: "What is the capitol of {{input}}?" }], + messages: [{ role: "user", content: "What is the capitol of {{state}}?" }], temperature: 0, }, }, { experimentId, - label: "Variant 2", + label: "Prompt Variant 2", sortIndex: 1, config: { model: "gpt-3.5-turbo", - messages: [{ role: "user", content: "What is the capitol of the US state {{input}}?" }], + messages: [{ role: "user", content: "What is the capitol of the US state {{state}}?" }], temperature: 0, }, }, @@ -69,13 +69,13 @@ await prisma.testScenario.createMany({ { experimentId, variableValues: { - input: "Washington", + state: "Washington", }, }, { experimentId, variableValues: { - input: "Georgia", + state: "Georgia", }, }, ], diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx new file mode 100644 index 0000000..1566121 --- /dev/null +++ b/src/components/OutputsTable/OutputCell.tsx @@ -0,0 +1,19 @@ +import { api } from "~/utils/api"; +import { PromptVariant, Scenario } from "./types"; + +export default function OutputCell({ + scenario, + variant, +}: { + scenario: Scenario; + variant: PromptVariant; +}) { + const output = api.outputs.get.useQuery({ + scenarioId: scenario.id, + variantId: variant.id, + }); + + if (!output.data) return null; + + return
{JSON.stringify(output.data.output.choices[0].message.content, null, 2)}
; +} diff --git a/src/components/OutputsTable/VariantHeader.tsx b/src/components/OutputsTable/VariantHeader.tsx new file mode 100644 index 0000000..c2120ed --- /dev/null +++ b/src/components/OutputsTable/VariantHeader.tsx @@ -0,0 +1,59 @@ +import { Header, Stack, Title } from "@mantine/core"; +import { PromptVariant } from "@prisma/client"; +import { useMonaco } from "@monaco-editor/react"; +import { useEffect, useRef, useState } from "react"; + +let isThemeDefined = false; + +export default function VariantHeader({ variant }: { variant: PromptVariant }) { + const monaco = useMonaco(); + const editorRef = useRef(null); + const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); + + useEffect(() => { + if (monaco && !isThemeDefined) { + monaco.editor.defineTheme("customTheme", { + base: "vs", + inherit: true, + rules: [], + colors: { + "editor.background": "#fafafa", + }, + }); + isThemeDefined = true; + } + }, [monaco]); + + useEffect(() => { + if (monaco) { + editorRef.current = monaco.editor.create(document.getElementById(editorId), { + value: JSON.stringify(variant.config, null, 2), + language: "json", + theme: "customTheme", + lineNumbers: "off", + minimap: { enabled: false }, + wrappingIndent: "indent", + wrappingStrategy: "advanced", + wordWrap: "on", + folding: false, + scrollbar: { vertical: "hidden" }, + wordWrapBreakAfterCharacters: "", + wordWrapBreakBeforeCharacters: "", + }); + } + + // Clean up the editor instance on unmount + return () => { + if (editorRef.current) { + editorRef.current.dispose(); + } + }; + }, [monaco, variant, editorId]); + + return ( + + {variant.label} +
+
+ ); +} diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx index 5251a6b..741c74d 100644 --- a/src/components/OutputsTable/index.tsx +++ b/src/components/OutputsTable/index.tsx @@ -1,9 +1,12 @@ import { MRT_ColumnDef, MantineReactTable } from "mantine-react-table"; import { useMemo } from "react"; import { RouterOutputs, api } from "~/utils/api"; +import { PromptVariant } from "./types"; +import VariantHeader from "./VariantHeader"; +import OutputCell from "./OutputCell"; type CellData = { - variant: NonNullable[0]; + variant: PromptVariant; scenario: NonNullable[0]; }; @@ -42,15 +45,9 @@ export default function OutputsTable({ experimentId }: { experimentId: string | (variant): MRT_ColumnDef => ({ id: variant.id, header: variant.label, - // size: 300, - Cell: ({ row }) => { - const cellData = row.original[variant.id]; - return ( -
- {row.original.scenario.id} | {variant.id} -
- ); - }, + Header: , + size: 400, + Cell: ({ row }) => , }) ) ?? []), ], @@ -64,9 +61,11 @@ export default function OutputsTable({ experimentId }: { experimentId: string | scenario, } as TableRow; }) ?? [], - [variants.data, scenarios.data] + [scenarios.data] ); + if (!variants.data || !scenarios.data) return null; + return ( .mantine-TableHeadCell-Content-Labels": { + width: "100%", + height: "100%", + + "& > .mantine-TableHeadCell-Content-Wrapper": { + width: "100%", + height: "100%", + }, + }, + }, + }, + }} /> ); - // return
OutputsTable
; } diff --git a/src/components/OutputsTable/types.ts b/src/components/OutputsTable/types.ts new file mode 100644 index 0000000..10f7623 --- /dev/null +++ b/src/components/OutputsTable/types.ts @@ -0,0 +1,5 @@ +import { RouterOutputs } from "~/utils/api"; + +export type PromptVariant = NonNullable[0]; + +export type Scenario = NonNullable[0]; diff --git a/src/components/nav/AppNav.tsx b/src/components/nav/AppNav.tsx index 8cf7892..9eaf04d 100644 --- a/src/components/nav/AppNav.tsx +++ b/src/components/nav/AppNav.tsx @@ -19,7 +19,7 @@ const useStyles = createStyles((theme) => ({ paddingBottom: theme.spacing.md, marginBottom: `calc(${theme.spacing.md} * 1.5)`, borderBottom: `${rem(1)} solid ${ - theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2] + theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4] }`, }, @@ -27,12 +27,12 @@ const useStyles = createStyles((theme) => ({ paddingTop: theme.spacing.md, marginTop: theme.spacing.md, borderTop: `${rem(1)} solid ${ - theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[2] + theme.colorScheme === "dark" ? theme.colors.dark[4] : theme.colors.gray[4] }`, }, link: { - ...theme.fn.focusStyles(), + ...(theme.fn.focusStyles() as Record), display: "flex", alignItems: "center", textDecoration: "none", @@ -101,7 +101,7 @@ export default function AppNav(props: { children: React.ReactNode; title?: strin return ( - {props.title && `${props.title} | `}Prompt Bench + {props.title ? `${props.title} | Prompt Bench` : "Prompt Bench"} diff --git a/src/pages/experiments/[id].tsx b/src/pages/experiments/[id].tsx index abdc882..37f650b 100644 --- a/src/pages/experiments/[id].tsx +++ b/src/pages/experiments/[id].tsx @@ -7,16 +7,15 @@ import { api } from "~/utils/api"; export default function Experiment() { const router = useRouter(); - console.log(router.query.id); const experiment = api.experiments.get.useQuery( { id: router.query.id as string }, { enabled: !!router.query.id } ); - if (!experiment.data) { + if (!experiment.isLoading && !experiment.data) { return ( -
+
Experiment not found 😕
@@ -24,7 +23,7 @@ export default function Experiment() { } return ( - + diff --git a/src/server/api/root.router.ts b/src/server/api/root.router.ts index 32e303f..dc00e39 100644 --- a/src/server/api/root.router.ts +++ b/src/server/api/root.router.ts @@ -2,6 +2,7 @@ import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router import { createTRPCRouter } from "~/server/api/trpc"; import { experimentsRouter } from "./routers/experiments.router"; import { scenariosRouter } from "./routers/scenarios.router"; +import { modelOutputsRouter } from "./routers/modelOutputs.router"; /** * This is the primary router for your server. @@ -12,6 +13,7 @@ export const appRouter = createTRPCRouter({ promptVariants: promptVariantsRouter, experiments: experimentsRouter, scenarios: scenariosRouter, + outputs: modelOutputsRouter, }); // export type definition of API diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts new file mode 100644 index 0000000..40d08fc --- /dev/null +++ b/src/server/api/routers/modelOutputs.router.ts @@ -0,0 +1,53 @@ +import { z } from "zod"; +import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import fillTemplate, { JSONSerializable, VariableMap } from "~/server/utils/fillTemplate"; +import { getChatCompletion } from "~/server/utils/openai"; + +export const modelOutputsRouter = createTRPCRouter({ + get: publicProcedure + .input(z.object({ scenarioId: z.string(), variantId: z.string() })) + .query(async ({ input }) => { + const existing = await prisma.modelOutput.findUnique({ + where: { + promptVariantId_testScenarioId: { + promptVariantId: input.variantId, + testScenarioId: input.scenarioId, + }, + }, + }); + + if (existing) return existing; + + const variant = await prisma.promptVariant.findUnique({ + where: { + id: input.variantId, + }, + }); + + const scenario = await prisma.testScenario.findUnique({ + where: { + id: input.scenarioId, + }, + }); + + if (!variant || !scenario) return null; + + const filledTemplate = fillTemplate( + variant.config as JSONSerializable, + scenario.variableValues as VariableMap + ); + + const modelResponse = await getChatCompletion(filledTemplate, process.env.OPENAI_API_KEY!); + + const modelOutput = await prisma.modelOutput.create({ + data: { + promptVariantId: input.variantId, + testScenarioId: input.scenarioId, + output: modelResponse, + }, + }); + + return modelOutput; + }), +}); diff --git a/src/server/utils/fillTemplate.ts b/src/server/utils/fillTemplate.ts new file mode 100644 index 0000000..cb30462 --- /dev/null +++ b/src/server/utils/fillTemplate.ts @@ -0,0 +1,27 @@ +export type JSONSerializable = + | string + | number + | boolean + | null + | JSONSerializable[] + | { [key: string]: JSONSerializable }; + +export type VariableMap = Record; + +export default function fillTemplate( + template: T, + variables: VariableMap +): T { + if (typeof template === "string") { + return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T; + } else if (Array.isArray(template)) { + return template.map((item) => fillTemplate(item, variables)) as T; + } else if (typeof template === "object" && template !== null) { + return Object.keys(template).reduce((acc, key) => { + acc[key] = fillTemplate(template[key] as JSONSerializable, variables); + return acc; + }, {} as { [key: string]: JSONSerializable } & T); + } else { + return template; + } +} diff --git a/src/server/utils/openai.ts b/src/server/utils/openai.ts new file mode 100644 index 0000000..5f5156f --- /dev/null +++ b/src/server/utils/openai.ts @@ -0,0 +1,19 @@ +import { JSONSerializable } from "./fillTemplate"; + +export async function getChatCompletion(payload: JSONSerializable, apiKey: string) { + const response = await fetch("https://api.openai.com/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`OpenAI API request failed with status ${response.status}`); + } + + const data = (await response.json()) as JSONSerializable; + return data; +}