From ca78406ad17a32d6c7374a52e159583b45f59254 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Fri, 23 Jun 2023 12:18:09 -0700 Subject: [PATCH] saving prompt configs works --- package.json | 1 + pnpm-lock.yaml | 40 ++++++++ prisma/seed.ts | 8 ++ .../OutputsTable/VariantConfigEditor.tsx | 78 +++++++++++++++ src/components/OutputsTable/VariantHeader.tsx | 94 +++++++++---------- src/components/OutputsTable/index.tsx | 16 +++- src/pages/_app.tsx | 2 + src/server/api/routers/modelOutputs.router.ts | 3 +- .../api/routers/promptVariants.router.ts | 55 +++++++++++ src/server/types.ts | 10 ++ src/server/utils/fillTemplate.ts | 8 +- src/server/utils/openai.ts | 2 +- src/utils/hooks.ts | 35 +++++++ 13 files changed, 291 insertions(+), 61 deletions(-) create mode 100644 src/components/OutputsTable/VariantConfigEditor.tsx create mode 100644 src/server/types.ts diff --git a/package.json b/package.json index 93f669e..ca81000 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ "@mantine/form": "^6.0.14", "@mantine/hooks": "^6.0.14", "@mantine/next": "^6.0.14", + "@mantine/notifications": "^6.0.14", "@monaco-editor/react": "^4.5.1", "@next-auth/prisma-adapter": "^1.0.5", "@prisma/client": "^4.14.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c16098c..d15e772 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -26,6 +26,9 @@ dependencies: '@mantine/next': specifier: ^6.0.14 version: 6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(next@13.4.2)(react-dom@18.2.0)(react@18.2.0) + '@mantine/notifications': + specifier: ^6.0.14 + version: 6.0.14(@mantine/core@6.0.14)(@mantine/hooks@6.0.14)(react-dom@18.2.0)(react@18.2.0) '@monaco-editor/react': specifier: ^4.5.1 version: 4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0) @@ -653,6 +656,22 @@ packages: - '@emotion/server' dev: false + /@mantine/notifications@6.0.14(@mantine/core@6.0.14)(@mantine/hooks@6.0.14)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-ElzIVojgAplm9Gtq1qZWR/kjGupttRq8ctTUYmANV8yyXcbpErFr45RlYjDgJs2klQcZid3Pq7hVsjGKLF2MQw==} + peerDependencies: + '@mantine/core': 6.0.14 + '@mantine/hooks': 6.0.14 + react: '>=16.8.0' + react-dom: '>=16.8.0' + dependencies: + '@mantine/core': 6.0.14(@emotion/react@11.11.1)(@mantine/hooks@6.0.14)(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0) + '@mantine/hooks': 6.0.14(react@18.2.0) + '@mantine/utils': 6.0.14(react@18.2.0) + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + react-transition-group: 4.4.2(react-dom@18.2.0)(react@18.2.0) + dev: false + /@mantine/ssr@6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-vYWSUFIuwUyhtyAMUqceZHR5GslwPIY8/C1vhPF5xXwhLCoY33jpBd+06cqvMmge624NKUerTQKE3Lw39Yli8A==} peerDependencies: @@ -1720,6 +1739,13 @@ packages: esutils: 2.0.3 dev: true + /dom-helpers@5.2.1: + resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==} + dependencies: + '@babel/runtime': 7.22.5 + csstype: 3.1.2 + dev: false + /dom-serializer@1.4.1: resolution: {integrity: sha512-VHwB3KfrcOOkelEG2ZOfxqLZdfkil8PtJi4P8N2MMXucZq2yLp75ClViUlOVwyoHEDjYU433Aq+5zWP61+RGag==} dependencies: @@ -3365,6 +3391,20 @@ packages: - '@types/react' dev: false + /react-transition-group@4.4.2(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-/RNYfRAMlZwDSr6z4zNKV6xu53/e2BuaBbGhbyYIXTrmgu/bGHzmqOs7mJSJBHy9Ud+ApHx3QjrkKSp1pxvlFg==} + peerDependencies: + react: '>=16.6.0' + react-dom: '>=16.6.0' + dependencies: + '@babel/runtime': 7.22.5 + dom-helpers: 5.2.1 + loose-envify: 1.4.0 + prop-types: 15.8.1 + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /react@18.2.0: resolution: {integrity: sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==} engines: {node: '>=0.10.0'} diff --git a/prisma/seed.ts b/prisma/seed.ts index 3b8accf..ebd1ad8 100644 --- a/prisma/seed.ts +++ b/prisma/seed.ts @@ -12,6 +12,14 @@ const experiment = await prisma.experiment.upsert({ }, }); +await prisma.modelOutput.deleteMany({ + where: { + promptVariant: { + experimentId, + }, + }, +}); + await prisma.promptVariant.deleteMany({ where: { experimentId, diff --git a/src/components/OutputsTable/VariantConfigEditor.tsx b/src/components/OutputsTable/VariantConfigEditor.tsx new file mode 100644 index 0000000..fb2716a --- /dev/null +++ b/src/components/OutputsTable/VariantConfigEditor.tsx @@ -0,0 +1,78 @@ +import { Box, Button, Stack, Title } from "@mantine/core"; +import { useMonaco } from "@monaco-editor/react"; +import { useRef, useEffect, useState } from "react"; +import { set } from "zod"; +import { useHandledAsyncCallback } from "~/utils/hooks"; + +let isThemeDefined = false; + +export default function VariantConfigEditor(props: { + initialConfig: string; + onSave: (currentConfig: string) => Promise; +}) { + const monaco = useMonaco(); + const editorRef = useRef["editor"]["create"]> | null>(null); + const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); + const [isChanged, setIsChanged] = useState(false); + + const [onSave] = useHandledAsyncCallback(async () => { + const currentConfig = editorRef.current?.getValue(); + if (!currentConfig) return; + await props.onSave(currentConfig); + }, [props.onSave]); + + useEffect(() => { + if (monaco) { + if (!isThemeDefined) { + monaco.editor.defineTheme("customTheme", { + base: "vs", + inherit: true, + rules: [], + colors: { + "editor.background": "#fafafa", + }, + }); + isThemeDefined = true; + } + + editorRef.current = monaco.editor.create(document.getElementById(editorId) as HTMLElement, { + value: props.initialConfig, + language: "json", + theme: "customTheme", + lineNumbers: "off", + minimap: { enabled: false }, + wrappingIndent: "indent", + wrappingStrategy: "advanced", + wordWrap: "on", + folding: false, + scrollbar: { vertical: "hidden", alwaysConsumeMouseWheel: false }, + wordWrapBreakAfterCharacters: "", + wordWrapBreakBeforeCharacters: "", + }); + + editorRef.current.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave); + + editorRef.current.onDidChangeModelContent(() => { + const currentConfig = editorRef.current?.getValue(); + if (currentConfig !== props.initialConfig) { + setIsChanged(true); + } else { + setIsChanged(false); + } + }); + + return () => editorRef.current?.dispose(); + } + }, [monaco, editorId, props.initialConfig, onSave]); + + return ( + +
+ {isChanged && ( + + )} +
+ ); +} diff --git a/src/components/OutputsTable/VariantHeader.tsx b/src/components/OutputsTable/VariantHeader.tsx index c2120ed..1f19fbd 100644 --- a/src/components/OutputsTable/VariantHeader.tsx +++ b/src/components/OutputsTable/VariantHeader.tsx @@ -1,59 +1,59 @@ -import { Header, Stack, Title } from "@mantine/core"; -import { PromptVariant } from "@prisma/client"; +import { Button, Stack, Title } from "@mantine/core"; import { useMonaco } from "@monaco-editor/react"; -import { useEffect, useRef, useState } from "react"; - -let isThemeDefined = false; +import { useCallback, useEffect, useRef, useState } from "react"; +import type { PromptVariant } from "./types"; +import { api } from "~/utils/api"; +import { useHandledAsyncCallback } from "~/utils/hooks"; +import { notifications } from "@mantine/notifications"; +import { type JSONSerializable } from "~/server/types"; +import VariantConfigEditor from "./VariantConfigEditor"; export default function VariantHeader({ variant }: { variant: PromptVariant }) { - const monaco = useMonaco(); - const editorRef = useRef(null); - const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); + const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation(); + const utils = api.useContext(); - useEffect(() => { - if (monaco && !isThemeDefined) { - monaco.editor.defineTheme("customTheme", { - base: "vs", - inherit: true, - rules: [], - colors: { - "editor.background": "#fafafa", - }, - }); - isThemeDefined = true; - } - }, [monaco]); - - useEffect(() => { - if (monaco) { - editorRef.current = monaco.editor.create(document.getElementById(editorId), { - value: JSON.stringify(variant.config, null, 2), - language: "json", - theme: "customTheme", - lineNumbers: "off", - minimap: { enabled: false }, - wrappingIndent: "indent", - wrappingStrategy: "advanced", - wordWrap: "on", - folding: false, - scrollbar: { vertical: "hidden" }, - wordWrapBreakAfterCharacters: "", - wordWrapBreakBeforeCharacters: "", - }); - } - - // Clean up the editor instance on unmount - return () => { - if (editorRef.current) { - editorRef.current.dispose(); + const onSave = useCallback( + async (currentConfig: string) => { + let parsedConfig: JSONSerializable; + try { + parsedConfig = JSON.parse(currentConfig) as JSONSerializable; + } catch (e) { + notifications.show({ + title: "Invalid JSON", + message: "Please fix the JSON before saving.", + color: "red", + }); + return; } - }; - }, [monaco, variant, editorId]); + + if (parsedConfig === null) { + notifications.show({ + title: "Invalid JSON", + message: "Please fix the JSON before saving.", + color: "red", + }); + return; + } + + await replaceWithConfig.mutateAsync({ + id: variant.id, + config: currentConfig, + }); + + await utils.promptVariants.list.invalidate(); + + // TODO: invalidate the variants query + }, + [variant.id, replaceWithConfig] + ); return ( {variant.label} -
+
); } diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx index 741c74d..a7ffa05 100644 --- a/src/components/OutputsTable/index.tsx +++ b/src/components/OutputsTable/index.tsx @@ -30,8 +30,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string | { enabled: !!experimentId } ); - const columns = useMemo[]>( - () => [ + const columns = useMemo[]>(() => { + console.log( + "rebuilding cols", + variants.data?.map((variant) => variant.label) + ); + return [ { id: "scenario", header: "Scenario", @@ -50,9 +54,8 @@ export default function OutputsTable({ experimentId }: { experimentId: string | Cell: ({ row }) => , }) ) ?? []), - ], - [variants.data] - ); + ]; + }, [variants.data]); const tableData = useMemo( () => @@ -84,6 +87,9 @@ export default function OutputsTable({ experimentId }: { experimentId: string | enableHiding={false} enableColumnActions={false} enableColumnResizing + state={{ + columnOrder: ["scenario", ...variants.data.map((variant) => variant.id)], + }} mantineTableProps={{ sx: { th: { diff --git a/src/pages/_app.tsx b/src/pages/_app.tsx index 0de730d..6f37f81 100644 --- a/src/pages/_app.tsx +++ b/src/pages/_app.tsx @@ -3,6 +3,7 @@ import { SessionProvider } from "next-auth/react"; import { type AppType } from "next/app"; import { api } from "~/utils/api"; import { MantineProvider } from "@mantine/core"; +import { Notifications } from "@mantine/notifications"; const MyApp: AppType<{ session: Session | null }> = ({ Component, @@ -11,6 +12,7 @@ const MyApp: AppType<{ session: Session | null }> = ({ return ( + diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index 40d08fc..75249cd 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -1,7 +1,8 @@ import { z } from "zod"; import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; -import fillTemplate, { JSONSerializable, VariableMap } from "~/server/utils/fillTemplate"; +import fillTemplate, { VariableMap } from "~/server/utils/fillTemplate"; +import { JSONSerializable } from "~/server/types"; import { getChatCompletion } from "~/server/utils/openai"; export const modelOutputsRouter = createTRPCRouter({ diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index ff3a254..b51f522 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -1,13 +1,68 @@ import { z } from "zod"; import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; +import { JSONSerializable, OpenAIChatConfig } from "~/server/types"; export const promptVariantsRouter = createTRPCRouter({ list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { return await prisma.promptVariant.findMany({ where: { experimentId: input.experimentId, + visible: true, + }, + orderBy: { + sortIndex: "asc", }, }); }), + + replaceWithConfig: publicProcedure + .input( + z.object({ + id: z.string(), + config: z.string(), + }) + ) + .mutation(async ({ input }) => { + const existing = await prisma.promptVariant.findUnique({ + where: { + id: input.id, + }, + }); + + console.log("got existing", existing); + console.log("config", input.config); + + let parsedConfig; + try { + parsedConfig = JSON.parse(input.config) as OpenAIChatConfig; + } catch (e) { + throw new Error(`Invalid JSON: ${(e as Error).message}`); + } + + if (!existing) { + throw new Error(`Prompt Variant with id ${input.id} does not exist`); + } + + // Create a duplicate with only the config changed + const newVariant = await prisma.promptVariant.create({ + data: { + experimentId: existing.experimentId, + label: existing.label, + sortIndex: existing.sortIndex, + config: parsedConfig, + }, + }); + + await prisma.promptVariant.update({ + where: { + id: input.id, + }, + data: { + visible: false, + }, + }); + + return newVariant; + }), }); diff --git a/src/server/types.ts b/src/server/types.ts new file mode 100644 index 0000000..d110298 --- /dev/null +++ b/src/server/types.ts @@ -0,0 +1,10 @@ +export type JSONSerializable = + | string + | number + | boolean + | null + | JSONSerializable[] + | { [key: string]: JSONSerializable }; + +// Placeholder for now +export type OpenAIChatConfig = NonNullable; diff --git a/src/server/utils/fillTemplate.ts b/src/server/utils/fillTemplate.ts index cb30462..d39cdce 100644 --- a/src/server/utils/fillTemplate.ts +++ b/src/server/utils/fillTemplate.ts @@ -1,10 +1,4 @@ -export type JSONSerializable = - | string - | number - | boolean - | null - | JSONSerializable[] - | { [key: string]: JSONSerializable }; +import { JSONSerializable } from "../types"; export type VariableMap = Record; diff --git a/src/server/utils/openai.ts b/src/server/utils/openai.ts index 5f5156f..b83967a 100644 --- a/src/server/utils/openai.ts +++ b/src/server/utils/openai.ts @@ -1,4 +1,4 @@ -import { JSONSerializable } from "./fillTemplate"; +import { JSONSerializable } from "../types"; export async function getChatCompletion(payload: JSONSerializable, apiKey: string) { const response = await fetch("https://api.openai.com/v1/chat/completions", { diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts index 94f7a26..29a033a 100644 --- a/src/utils/hooks.ts +++ b/src/utils/hooks.ts @@ -1,4 +1,5 @@ import { useRouter } from "next/router"; +import { useCallback, useEffect, useState } from "react"; import { api } from "~/utils/api"; export const useExperiment = () => { @@ -10,3 +11,37 @@ export const useExperiment = () => { return experiment; }; + +export function useHandledAsyncCallback Promise>( + callback: T, + deps: React.DependencyList +) { + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + + const wrappedCallback = useCallback((...args: Parameters) => { + setLoading(true); + setError(null); + + callback(...args) + .catch((error) => { + setError(error as Error); + console.error(error); + }) + .finally(() => { + setLoading(false); + }); + }, deps); + + return [wrappedCallback, loading, error] as const; +} + +// Have to do this ugly thing to convince Next not to try to access `navigator` +// on the server side at build time, when it isn't defined. +export const useModifierKeyLabel = () => { + const [label, setLabel] = useState(""); + useEffect(() => { + setLabel(navigator?.platform?.startsWith("Mac") ? "⌘" : "Ctrl"); + }, []); + return label; +};