saving prompt configs works

This commit is contained in:
Kyle Corbitt
2023-06-23 12:18:09 -07:00
parent a31c112745
commit ca78406ad1
13 changed files with 291 additions and 61 deletions

View File

@@ -18,6 +18,7 @@
"@mantine/form": "^6.0.14",
"@mantine/hooks": "^6.0.14",
"@mantine/next": "^6.0.14",
"@mantine/notifications": "^6.0.14",
"@monaco-editor/react": "^4.5.1",
"@next-auth/prisma-adapter": "^1.0.5",
"@prisma/client": "^4.14.0",

40
pnpm-lock.yaml generated
View File

@@ -26,6 +26,9 @@ dependencies:
'@mantine/next':
specifier: ^6.0.14
version: 6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(next@13.4.2)(react-dom@18.2.0)(react@18.2.0)
'@mantine/notifications':
specifier: ^6.0.14
version: 6.0.14(@mantine/core@6.0.14)(@mantine/hooks@6.0.14)(react-dom@18.2.0)(react@18.2.0)
'@monaco-editor/react':
specifier: ^4.5.1
version: 4.5.1(monaco-editor@0.39.0)(react-dom@18.2.0)(react@18.2.0)
@@ -653,6 +656,22 @@ packages:
- '@emotion/server'
dev: false
/@mantine/notifications@6.0.14(@mantine/core@6.0.14)(@mantine/hooks@6.0.14)(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-ElzIVojgAplm9Gtq1qZWR/kjGupttRq8ctTUYmANV8yyXcbpErFr45RlYjDgJs2klQcZid3Pq7hVsjGKLF2MQw==}
peerDependencies:
'@mantine/core': 6.0.14
'@mantine/hooks': 6.0.14
react: '>=16.8.0'
react-dom: '>=16.8.0'
dependencies:
'@mantine/core': 6.0.14(@emotion/react@11.11.1)(@mantine/hooks@6.0.14)(@types/react@18.2.6)(react-dom@18.2.0)(react@18.2.0)
'@mantine/hooks': 6.0.14(react@18.2.0)
'@mantine/utils': 6.0.14(react@18.2.0)
react: 18.2.0
react-dom: 18.2.0(react@18.2.0)
react-transition-group: 4.4.2(react-dom@18.2.0)(react@18.2.0)
dev: false
/@mantine/ssr@6.0.14(@emotion/react@11.11.1)(@emotion/server@11.11.0)(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-vYWSUFIuwUyhtyAMUqceZHR5GslwPIY8/C1vhPF5xXwhLCoY33jpBd+06cqvMmge624NKUerTQKE3Lw39Yli8A==}
peerDependencies:
@@ -1720,6 +1739,13 @@ packages:
esutils: 2.0.3
dev: true
/dom-helpers@5.2.1:
resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==}
dependencies:
'@babel/runtime': 7.22.5
csstype: 3.1.2
dev: false
/dom-serializer@1.4.1:
resolution: {integrity: sha512-VHwB3KfrcOOkelEG2ZOfxqLZdfkil8PtJi4P8N2MMXucZq2yLp75ClViUlOVwyoHEDjYU433Aq+5zWP61+RGag==}
dependencies:
@@ -3365,6 +3391,20 @@ packages:
- '@types/react'
dev: false
/react-transition-group@4.4.2(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-/RNYfRAMlZwDSr6z4zNKV6xu53/e2BuaBbGhbyYIXTrmgu/bGHzmqOs7mJSJBHy9Ud+ApHx3QjrkKSp1pxvlFg==}
peerDependencies:
react: '>=16.6.0'
react-dom: '>=16.6.0'
dependencies:
'@babel/runtime': 7.22.5
dom-helpers: 5.2.1
loose-envify: 1.4.0
prop-types: 15.8.1
react: 18.2.0
react-dom: 18.2.0(react@18.2.0)
dev: false
/react@18.2.0:
resolution: {integrity: sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==}
engines: {node: '>=0.10.0'}

View File

@@ -12,6 +12,14 @@ const experiment = await prisma.experiment.upsert({
},
});
await prisma.modelOutput.deleteMany({
where: {
promptVariant: {
experimentId,
},
},
});
await prisma.promptVariant.deleteMany({
where: {
experimentId,

View File

@@ -0,0 +1,78 @@
import { Box, Button, Stack, Title } from "@mantine/core";
import { useMonaco } from "@monaco-editor/react";
import { useRef, useEffect, useState } from "react";
import { set } from "zod";
import { useHandledAsyncCallback } from "~/utils/hooks";
let isThemeDefined = false;
export default function VariantConfigEditor(props: {
initialConfig: string;
onSave: (currentConfig: string) => Promise<void>;
}) {
const monaco = useMonaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const [isChanged, setIsChanged] = useState(false);
const [onSave] = useHandledAsyncCallback(async () => {
const currentConfig = editorRef.current?.getValue();
if (!currentConfig) return;
await props.onSave(currentConfig);
}, [props.onSave]);
useEffect(() => {
if (monaco) {
if (!isThemeDefined) {
monaco.editor.defineTheme("customTheme", {
base: "vs",
inherit: true,
rules: [],
colors: {
"editor.background": "#fafafa",
},
});
isThemeDefined = true;
}
editorRef.current = monaco.editor.create(document.getElementById(editorId) as HTMLElement, {
value: props.initialConfig,
language: "json",
theme: "customTheme",
lineNumbers: "off",
minimap: { enabled: false },
wrappingIndent: "indent",
wrappingStrategy: "advanced",
wordWrap: "on",
folding: false,
scrollbar: { vertical: "hidden", alwaysConsumeMouseWheel: false },
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
});
editorRef.current.addCommand(monaco.KeyMod.CtrlCmd | monaco.KeyCode.Enter, onSave);
editorRef.current.onDidChangeModelContent(() => {
const currentConfig = editorRef.current?.getValue();
if (currentConfig !== props.initialConfig) {
setIsChanged(true);
} else {
setIsChanged(false);
}
});
return () => editorRef.current?.dispose();
}
}, [monaco, editorId, props.initialConfig, onSave]);
return (
<Box w="100%" pos="relative">
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
{isChanged && (
<Button size="xs" sx={{ position: "absolute", bottom: 0, right: 0 }} onClick={onSave}>
Save
</Button>
)}
</Box>
);
}

View File

@@ -1,59 +1,59 @@
import { Header, Stack, Title } from "@mantine/core";
import { PromptVariant } from "@prisma/client";
import { Button, Stack, Title } from "@mantine/core";
import { useMonaco } from "@monaco-editor/react";
import { useEffect, useRef, useState } from "react";
let isThemeDefined = false;
import { useCallback, useEffect, useRef, useState } from "react";
import type { PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { notifications } from "@mantine/notifications";
import { type JSONSerializable } from "~/server/types";
import VariantConfigEditor from "./VariantConfigEditor";
export default function VariantHeader({ variant }: { variant: PromptVariant }) {
const monaco = useMonaco();
const editorRef = useRef(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation();
const utils = api.useContext();
useEffect(() => {
if (monaco && !isThemeDefined) {
monaco.editor.defineTheme("customTheme", {
base: "vs",
inherit: true,
rules: [],
colors: {
"editor.background": "#fafafa",
const onSave = useCallback(
async (currentConfig: string) => {
let parsedConfig: JSONSerializable;
try {
parsedConfig = JSON.parse(currentConfig) as JSONSerializable;
} catch (e) {
notifications.show({
title: "Invalid JSON",
message: "Please fix the JSON before saving.",
color: "red",
});
return;
}
if (parsedConfig === null) {
notifications.show({
title: "Invalid JSON",
message: "Please fix the JSON before saving.",
color: "red",
});
return;
}
await replaceWithConfig.mutateAsync({
id: variant.id,
config: currentConfig,
});
await utils.promptVariants.list.invalidate();
// TODO: invalidate the variants query
},
});
isThemeDefined = true;
}
}, [monaco]);
useEffect(() => {
if (monaco) {
editorRef.current = monaco.editor.create(document.getElementById(editorId), {
value: JSON.stringify(variant.config, null, 2),
language: "json",
theme: "customTheme",
lineNumbers: "off",
minimap: { enabled: false },
wrappingIndent: "indent",
wrappingStrategy: "advanced",
wordWrap: "on",
folding: false,
scrollbar: { vertical: "hidden" },
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
});
}
// Clean up the editor instance on unmount
return () => {
if (editorRef.current) {
editorRef.current.dispose();
}
};
}, [monaco, variant, editorId]);
[variant.id, replaceWithConfig]
);
return (
<Stack w="100%">
<Title order={4}>{variant.label}</Title>
<div id={editorId} style={{ height: "300px", width: "100%" }}></div>
<VariantConfigEditor
initialConfig={JSON.stringify(variant.config, null, 2)}
onSave={onSave}
/>
</Stack>
);
}

View File

@@ -30,8 +30,12 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
{ enabled: !!experimentId }
);
const columns = useMemo<MRT_ColumnDef<TableRow>[]>(
() => [
const columns = useMemo<MRT_ColumnDef<TableRow>[]>(() => {
console.log(
"rebuilding cols",
variants.data?.map((variant) => variant.label)
);
return [
{
id: "scenario",
header: "Scenario",
@@ -50,9 +54,8 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
Cell: ({ row }) => <OutputCell scenario={row.original.scenario} variant={variant} />,
})
) ?? []),
],
[variants.data]
);
];
}, [variants.data]);
const tableData = useMemo(
() =>
@@ -84,6 +87,9 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
enableHiding={false}
enableColumnActions={false}
enableColumnResizing
state={{
columnOrder: ["scenario", ...variants.data.map((variant) => variant.id)],
}}
mantineTableProps={{
sx: {
th: {

View File

@@ -3,6 +3,7 @@ import { SessionProvider } from "next-auth/react";
import { type AppType } from "next/app";
import { api } from "~/utils/api";
import { MantineProvider } from "@mantine/core";
import { Notifications } from "@mantine/notifications";
const MyApp: AppType<{ session: Session | null }> = ({
Component,
@@ -11,6 +12,7 @@ const MyApp: AppType<{ session: Session | null }> = ({
return (
<SessionProvider session={session}>
<MantineProvider withGlobalStyles withNormalizeCSS>
<Notifications position="bottom-center" />
<Component {...pageProps} />
</MantineProvider>
</SessionProvider>

View File

@@ -1,7 +1,8 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import fillTemplate, { JSONSerializable, VariableMap } from "~/server/utils/fillTemplate";
import fillTemplate, { VariableMap } from "~/server/utils/fillTemplate";
import { JSONSerializable } from "~/server/types";
import { getChatCompletion } from "~/server/utils/openai";
export const modelOutputsRouter = createTRPCRouter({

View File

@@ -1,13 +1,68 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { JSONSerializable, OpenAIChatConfig } from "~/server/types";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
replaceWithConfig: publicProcedure
.input(
z.object({
id: z.string(),
config: z.string(),
})
)
.mutation(async ({ input }) => {
const existing = await prisma.promptVariant.findUnique({
where: {
id: input.id,
},
});
console.log("got existing", existing);
console.log("config", input.config);
let parsedConfig;
try {
parsedConfig = JSON.parse(input.config) as OpenAIChatConfig;
} catch (e) {
throw new Error(`Invalid JSON: ${(e as Error).message}`);
}
if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
// Create a duplicate with only the config changed
const newVariant = await prisma.promptVariant.create({
data: {
experimentId: existing.experimentId,
label: existing.label,
sortIndex: existing.sortIndex,
config: parsedConfig,
},
});
await prisma.promptVariant.update({
where: {
id: input.id,
},
data: {
visible: false,
},
});
return newVariant;
}),
});

10
src/server/types.ts Normal file
View File

@@ -0,0 +1,10 @@
export type JSONSerializable =
| string
| number
| boolean
| null
| JSONSerializable[]
| { [key: string]: JSONSerializable };
// Placeholder for now
export type OpenAIChatConfig = NonNullable<JSONSerializable>;

View File

@@ -1,10 +1,4 @@
export type JSONSerializable =
| string
| number
| boolean
| null
| JSONSerializable[]
| { [key: string]: JSONSerializable };
import { JSONSerializable } from "../types";
export type VariableMap = Record<string, string>;

View File

@@ -1,4 +1,4 @@
import { JSONSerializable } from "./fillTemplate";
import { JSONSerializable } from "../types";
export async function getChatCompletion(payload: JSONSerializable, apiKey: string) {
const response = await fetch("https://api.openai.com/v1/chat/completions", {

View File

@@ -1,4 +1,5 @@
import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react";
import { api } from "~/utils/api";
export const useExperiment = () => {
@@ -10,3 +11,37 @@ export const useExperiment = () => {
return experiment;
};
export function useHandledAsyncCallback<T extends (...args: any[]) => Promise<any>>(
callback: T,
deps: React.DependencyList
) {
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
const wrappedCallback = useCallback((...args: Parameters<T>) => {
setLoading(true);
setError(null);
callback(...args)
.catch((error) => {
setError(error as Error);
console.error(error);
})
.finally(() => {
setLoading(false);
});
}, deps);
return [wrappedCallback, loading, error] as const;
}
// Have to do this ugly thing to convince Next not to try to access `navigator`
// on the server side at build time, when it isn't defined.
export const useModifierKeyLabel = () => {
const [label, setLabel] = useState("");
useEffect(() => {
setLabel(navigator?.platform?.startsWith("Mac") ? "⌘" : "Ctrl");
}, []);
return label;
};