add prompt variants

This commit is contained in:
Kyle Corbitt
2023-06-23 23:04:50 -07:00
parent 87154fd4b7
commit cde22ac4bf
6 changed files with 158 additions and 24 deletions

View File

@@ -0,0 +1,37 @@
import { Button, Tooltip } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export default function NewVariantButton() {
const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext();
const [onClick] = useHandledAsyncCallback(async () => {
if (!experiment.data) return;
await mutation.mutateAsync({
experimentId: experiment.data.id,
});
await utils.promptVariants.list.invalidate();
}, [mutation]);
return (
<Tooltip label="Add Prompt Variant" placement="right">
<Button
w="100%"
borderRadius={0}
alignItems="flex-start"
justifyContent="center"
fontWeight="normal"
bgColor="blue.100"
_hover={{ bgColor: "blue.200" }}
py={2}
px={0}
onClick={onClick}
>
<BsPlus size={24} />
</Button>
</Tooltip>
);
}

View File

@@ -2,6 +2,7 @@ import { api } from "~/utils/api";
import { PromptVariant, Scenario } from "./types";
import { Center, Spinner, Text } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks";
import { JSONSerializable } from "~/server/types";
export default function OutputCell({
scenario,
@@ -18,18 +19,25 @@ export default function OutputCell({
experimentVariables.length === 0 ||
experimentVariables.some((v) => scenarioVariables[v] !== undefined);
let disabledReason: string | null = null;
if (!templateHasVariables) disabledReason = "Add a scenario variable to see output";
if (variant.config === null || Object.keys(variant.config).length === 0)
disabledReason = "Save your prompt variant to see output";
const output = api.outputs.get.useQuery(
{
scenarioId: scenario.id,
variantId: variant.id,
},
{ enabled: templateHasVariables }
{ enabled: disabledReason === null }
);
if (!templateHasVariables)
if (disabledReason)
return (
<Center h="100%">
<Text color="gray.500">Add a scenario variable to see output</Text>
<Text color="gray.500">{disabledReason}</Text>
</Center>
);

View File

@@ -1,19 +1,24 @@
import { Box, Button, HStack, Tooltip } from "@chakra-ui/react";
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
import { useMonaco } from "@monaco-editor/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useRef, useEffect, useState, useCallback, useMemo } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { PromptVariant } from "./types";
import { JSONSerializable } from "~/server/types";
import { api } from "~/utils/api";
let isThemeDefined = false;
export default function VariantConfigEditor(props: {
savedConfig: string;
onSave: (currentConfig: string) => Promise<void>;
}) {
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
const monaco = useMonaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const [isChanged, setIsChanged] = useState(false);
const savedConfigRef = useRef(props.savedConfig);
const savedConfig = useMemo(
() => JSON.stringify(props.variant.config, null, 2),
[props.variant.config]
);
const savedConfigRef = useRef(savedConfig);
const modifierKey = useModifierKeyLabel();
@@ -23,12 +28,44 @@ export default function VariantConfigEditor(props: {
setIsChanged(currentConfig !== savedConfigRef.current);
}, []);
const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation();
const utils = api.useContext();
const toast = useToast();
const [onSave] = useHandledAsyncCallback(async () => {
const currentConfig = editorRef.current?.getValue();
if (!currentConfig) return;
await props.onSave(currentConfig);
let parsedConfig: JSONSerializable;
try {
parsedConfig = JSON.parse(currentConfig) as JSONSerializable;
} catch (e) {
toast({
title: "Invalid JSON",
description: "Please fix the JSON before saving.",
status: "error",
});
return;
}
if (parsedConfig === null) {
toast({
title: "Invalid JSON",
description: "Please fix the JSON before saving.",
status: "error",
});
return;
}
await replaceWithConfig.mutateAsync({
id: props.variant.id,
config: currentConfig,
});
await utils.promptVariants.list.invalidate();
checkForChanges();
}, [props.onSave, checkForChanges]);
}, [checkForChanges]);
useEffect(() => {
if (monaco) {
@@ -47,7 +84,7 @@ export default function VariantConfigEditor(props: {
const container = document.getElementById(editorId) as HTMLElement;
editorRef.current = monaco.editor.create(container, {
value: props.savedConfig,
value: savedConfig,
language: "json",
theme: "customTheme",
lineNumbers: "off",
@@ -93,16 +130,16 @@ export default function VariantConfigEditor(props: {
}, [monaco, editorId]);
useEffect(() => {
const savedConfigChanged = savedConfigRef.current !== props.savedConfig;
const savedConfigChanged = savedConfigRef.current !== savedConfig;
savedConfigRef.current = props.savedConfig;
savedConfigRef.current = savedConfig;
if (savedConfigChanged && editorRef.current?.getValue() !== props.savedConfig) {
editorRef.current?.setValue(props.savedConfig);
if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
editorRef.current?.setValue(savedConfig);
}
checkForChanges();
}, [props.savedConfig, checkForChanges]);
}, [savedConfig, checkForChanges]);
return (
<Box w="100%" pos="relative">
@@ -113,7 +150,7 @@ export default function VariantConfigEditor(props: {
colorScheme="gray"
size="sm"
onClick={() => {
editorRef.current?.setValue(props.savedConfig);
editorRef.current?.setValue(savedConfig);
checkForChanges();
}}
borderRadius={0}

View File

@@ -47,7 +47,7 @@ export default function VariantHeader({ variant }: { variant: PromptVariant }) {
return (
<Stack w="100%">
<EditableVariantLabel variant={variant} />
<VariantConfigEditor savedConfig={JSON.stringify(variant.config, null, 2)} onSave={onSave} />
<VariantConfigEditor variant={variant} />
</Stack>
);
}

View File

@@ -6,6 +6,9 @@ import ScenarioHeader from "./ScenarioHeader";
import React, { useState } from "react";
import { Box, Grid, GridItem, Heading } from "@chakra-ui/react";
import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton";
import EditableVariantLabel from "./EditableVariantLabel";
import VariantConfigEditor from "./VariantConfigEditor";
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
const [isHovered, setIsHovered] = useState(false);
@@ -54,8 +57,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
<Grid
p={4}
display="grid"
gridTemplateColumns={`200px repeat(${variants.data.length}, minmax(300px, 1fr))`}
overflowX="auto"
gridTemplateColumns={`200px repeat(${variants.data.length}, minmax(300px, 1fr)) 40px`}
sx={{
"> *": {
borderColor: "gray.300",
@@ -69,14 +71,33 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
},
}}
>
<GridItem display="flex" alignItems="flex-end">
<GridItem display="flex" alignItems="flex-end" rowSpan={2}>
<Heading size="md" fontWeight="bold">
Scenario
</Heading>
</GridItem>
{variants.data.map((variant) => (
<GridItem
key={variant.uiId}
padding={0}
sx={{ position: "sticky", top: 0, backgroundColor: "#fff", zIndex: 1 }}
>
<EditableVariantLabel variant={variant} />
</GridItem>
))}
<GridItem
borderBottomWidth={0}
rowSpan={scenarios.data.length + 1}
padding={0}
borderRightWidth={0}
sx={{ position: "sticky", top: 0, backgroundColor: "#fff", zIndex: 1 }}
>
<NewVariantButton />
</GridItem>
{variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0}>
<VariantHeader key={variant.uiId} variant={variant} />
<VariantConfigEditor variant={variant} />
</GridItem>
))}
{scenarios.data.map((scenario) => (

View File

@@ -16,6 +16,37 @@ export const promptVariantsRouter = createTRPCRouter({
});
}),
create: publicProcedure
.input(
z.object({
experimentId: z.string(),
})
)
.mutation(async ({ input }) => {
const maxSortIndex =
(
await prisma.promptVariant.aggregate({
where: {
experimentId: input.experimentId,
},
_max: {
sortIndex: true,
},
})
)._max.sortIndex ?? 0;
const newScenario = await prisma.promptVariant.create({
data: {
experimentId: input.experimentId,
label: `Prompt Variant ${maxSortIndex + 1}`,
sortIndex: maxSortIndex + 1,
config: {},
},
});
return newScenario;
}),
update: publicProcedure
.input(
z.object({