From cde22ac4bfe39a12a23ffcb9064bdf3da5ad2429 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Fri, 23 Jun 2023 23:04:50 -0700 Subject: [PATCH] add prompt variants --- .../OutputsTable/NewVariantButton.tsx | 37 ++++++++++ src/components/OutputsTable/OutputCell.tsx | 14 +++- .../OutputsTable/VariantConfigEditor.tsx | 69 ++++++++++++++----- src/components/OutputsTable/VariantHeader.tsx | 2 +- src/components/OutputsTable/index.tsx | 29 ++++++-- .../api/routers/promptVariants.router.ts | 31 +++++++++ 6 files changed, 158 insertions(+), 24 deletions(-) create mode 100644 src/components/OutputsTable/NewVariantButton.tsx diff --git a/src/components/OutputsTable/NewVariantButton.tsx b/src/components/OutputsTable/NewVariantButton.tsx new file mode 100644 index 0000000..c853b1f --- /dev/null +++ b/src/components/OutputsTable/NewVariantButton.tsx @@ -0,0 +1,37 @@ +import { Button, Tooltip } from "@chakra-ui/react"; +import { BsPlus } from "react-icons/bs"; +import { api } from "~/utils/api"; +import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; + +export default function NewVariantButton() { + const experiment = useExperiment(); + const mutation = api.promptVariants.create.useMutation(); + const utils = api.useContext(); + + const [onClick] = useHandledAsyncCallback(async () => { + if (!experiment.data) return; + await mutation.mutateAsync({ + experimentId: experiment.data.id, + }); + await utils.promptVariants.list.invalidate(); + }, [mutation]); + + return ( + + + + ); +} diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index 6f8b903..0d54f8e 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -2,6 +2,7 @@ import { api } from "~/utils/api"; import { PromptVariant, Scenario } from "./types"; import { Center, Spinner, Text } from "@chakra-ui/react"; import { useExperiment } from "~/utils/hooks"; +import { JSONSerializable } from "~/server/types"; export default function OutputCell({ scenario, @@ -18,18 +19,25 @@ export default function OutputCell({ experimentVariables.length === 0 || experimentVariables.some((v) => scenarioVariables[v] !== undefined); + let disabledReason: string | null = null; + + if (!templateHasVariables) disabledReason = "Add a scenario variable to see output"; + + if (variant.config === null || Object.keys(variant.config).length === 0) + disabledReason = "Save your prompt variant to see output"; + const output = api.outputs.get.useQuery( { scenarioId: scenario.id, variantId: variant.id, }, - { enabled: templateHasVariables } + { enabled: disabledReason === null } ); - if (!templateHasVariables) + if (disabledReason) return (
- Add a scenario variable to see output + {disabledReason}
); diff --git a/src/components/OutputsTable/VariantConfigEditor.tsx b/src/components/OutputsTable/VariantConfigEditor.tsx index fe21eff..4dcd0a0 100644 --- a/src/components/OutputsTable/VariantConfigEditor.tsx +++ b/src/components/OutputsTable/VariantConfigEditor.tsx @@ -1,19 +1,24 @@ -import { Box, Button, HStack, Tooltip } from "@chakra-ui/react"; +import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react"; import { useMonaco } from "@monaco-editor/react"; -import { useRef, useEffect, useState, useCallback } from "react"; +import { useRef, useEffect, useState, useCallback, useMemo } from "react"; import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; +import { PromptVariant } from "./types"; +import { JSONSerializable } from "~/server/types"; +import { api } from "~/utils/api"; let isThemeDefined = false; -export default function VariantConfigEditor(props: { - savedConfig: string; - onSave: (currentConfig: string) => Promise; -}) { +export default function VariantConfigEditor(props: { variant: PromptVariant }) { const monaco = useMonaco(); const editorRef = useRef["editor"]["create"]> | null>(null); const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); const [isChanged, setIsChanged] = useState(false); - const savedConfigRef = useRef(props.savedConfig); + + const savedConfig = useMemo( + () => JSON.stringify(props.variant.config, null, 2), + [props.variant.config] + ); + const savedConfigRef = useRef(savedConfig); const modifierKey = useModifierKeyLabel(); @@ -23,12 +28,44 @@ export default function VariantConfigEditor(props: { setIsChanged(currentConfig !== savedConfigRef.current); }, []); + const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation(); + const utils = api.useContext(); + const toast = useToast(); + const [onSave] = useHandledAsyncCallback(async () => { const currentConfig = editorRef.current?.getValue(); if (!currentConfig) return; - await props.onSave(currentConfig); + + let parsedConfig: JSONSerializable; + try { + parsedConfig = JSON.parse(currentConfig) as JSONSerializable; + } catch (e) { + toast({ + title: "Invalid JSON", + description: "Please fix the JSON before saving.", + status: "error", + }); + return; + } + + if (parsedConfig === null) { + toast({ + title: "Invalid JSON", + description: "Please fix the JSON before saving.", + status: "error", + }); + return; + } + + await replaceWithConfig.mutateAsync({ + id: props.variant.id, + config: currentConfig, + }); + + await utils.promptVariants.list.invalidate(); + checkForChanges(); - }, [props.onSave, checkForChanges]); + }, [checkForChanges]); useEffect(() => { if (monaco) { @@ -47,7 +84,7 @@ export default function VariantConfigEditor(props: { const container = document.getElementById(editorId) as HTMLElement; editorRef.current = monaco.editor.create(container, { - value: props.savedConfig, + value: savedConfig, language: "json", theme: "customTheme", lineNumbers: "off", @@ -93,16 +130,16 @@ export default function VariantConfigEditor(props: { }, [monaco, editorId]); useEffect(() => { - const savedConfigChanged = savedConfigRef.current !== props.savedConfig; + const savedConfigChanged = savedConfigRef.current !== savedConfig; - savedConfigRef.current = props.savedConfig; + savedConfigRef.current = savedConfig; - if (savedConfigChanged && editorRef.current?.getValue() !== props.savedConfig) { - editorRef.current?.setValue(props.savedConfig); + if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) { + editorRef.current?.setValue(savedConfig); } checkForChanges(); - }, [props.savedConfig, checkForChanges]); + }, [savedConfig, checkForChanges]); return ( @@ -113,7 +150,7 @@ export default function VariantConfigEditor(props: { colorScheme="gray" size="sm" onClick={() => { - editorRef.current?.setValue(props.savedConfig); + editorRef.current?.setValue(savedConfig); checkForChanges(); }} borderRadius={0} diff --git a/src/components/OutputsTable/VariantHeader.tsx b/src/components/OutputsTable/VariantHeader.tsx index c45aefb..79172c5 100644 --- a/src/components/OutputsTable/VariantHeader.tsx +++ b/src/components/OutputsTable/VariantHeader.tsx @@ -47,7 +47,7 @@ export default function VariantHeader({ variant }: { variant: PromptVariant }) { return ( - + ); } diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx index 429c01f..31a08a6 100644 --- a/src/components/OutputsTable/index.tsx +++ b/src/components/OutputsTable/index.tsx @@ -6,6 +6,9 @@ import ScenarioHeader from "./ScenarioHeader"; import React, { useState } from "react"; import { Box, Grid, GridItem, Heading } from "@chakra-ui/react"; import NewScenarioButton from "./NewScenarioButton"; +import NewVariantButton from "./NewVariantButton"; +import EditableVariantLabel from "./EditableVariantLabel"; +import VariantConfigEditor from "./VariantConfigEditor"; const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => { const [isHovered, setIsHovered] = useState(false); @@ -54,8 +57,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string | *": { borderColor: "gray.300", @@ -69,14 +71,33 @@ export default function OutputsTable({ experimentId }: { experimentId: string | }, }} > - + Scenario + {variants.data.map((variant) => ( + + + + ))} + + + + {variants.data.map((variant) => ( - + ))} {scenarios.data.map((scenario) => ( diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index e896817..c673fd6 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -16,6 +16,37 @@ export const promptVariantsRouter = createTRPCRouter({ }); }), + create: publicProcedure + .input( + z.object({ + experimentId: z.string(), + }) + ) + .mutation(async ({ input }) => { + const maxSortIndex = + ( + await prisma.promptVariant.aggregate({ + where: { + experimentId: input.experimentId, + }, + _max: { + sortIndex: true, + }, + }) + )._max.sortIndex ?? 0; + + const newScenario = await prisma.promptVariant.create({ + data: { + experimentId: input.experimentId, + label: `Prompt Variant ${maxSortIndex + 1}`, + sortIndex: maxSortIndex + 1, + config: {}, + }, + }); + + return newScenario; + }), + update: publicProcedure .input( z.object({