add prompt variants
This commit is contained in:
37
src/components/OutputsTable/NewVariantButton.tsx
Normal file
37
src/components/OutputsTable/NewVariantButton.tsx
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import { Button, Tooltip } from "@chakra-ui/react";
|
||||||
|
import { BsPlus } from "react-icons/bs";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
|
export default function NewVariantButton() {
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const mutation = api.promptVariants.create.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
|
||||||
|
const [onClick] = useHandledAsyncCallback(async () => {
|
||||||
|
if (!experiment.data) return;
|
||||||
|
await mutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [mutation]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip label="Add Prompt Variant" placement="right">
|
||||||
|
<Button
|
||||||
|
w="100%"
|
||||||
|
borderRadius={0}
|
||||||
|
alignItems="flex-start"
|
||||||
|
justifyContent="center"
|
||||||
|
fontWeight="normal"
|
||||||
|
bgColor="blue.100"
|
||||||
|
_hover={{ bgColor: "blue.200" }}
|
||||||
|
py={2}
|
||||||
|
px={0}
|
||||||
|
onClick={onClick}
|
||||||
|
>
|
||||||
|
<BsPlus size={24} />
|
||||||
|
</Button>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ import { api } from "~/utils/api";
|
|||||||
import { PromptVariant, Scenario } from "./types";
|
import { PromptVariant, Scenario } from "./types";
|
||||||
import { Center, Spinner, Text } from "@chakra-ui/react";
|
import { Center, Spinner, Text } from "@chakra-ui/react";
|
||||||
import { useExperiment } from "~/utils/hooks";
|
import { useExperiment } from "~/utils/hooks";
|
||||||
|
import { JSONSerializable } from "~/server/types";
|
||||||
|
|
||||||
export default function OutputCell({
|
export default function OutputCell({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -18,18 +19,25 @@ export default function OutputCell({
|
|||||||
experimentVariables.length === 0 ||
|
experimentVariables.length === 0 ||
|
||||||
experimentVariables.some((v) => scenarioVariables[v] !== undefined);
|
experimentVariables.some((v) => scenarioVariables[v] !== undefined);
|
||||||
|
|
||||||
|
let disabledReason: string | null = null;
|
||||||
|
|
||||||
|
if (!templateHasVariables) disabledReason = "Add a scenario variable to see output";
|
||||||
|
|
||||||
|
if (variant.config === null || Object.keys(variant.config).length === 0)
|
||||||
|
disabledReason = "Save your prompt variant to see output";
|
||||||
|
|
||||||
const output = api.outputs.get.useQuery(
|
const output = api.outputs.get.useQuery(
|
||||||
{
|
{
|
||||||
scenarioId: scenario.id,
|
scenarioId: scenario.id,
|
||||||
variantId: variant.id,
|
variantId: variant.id,
|
||||||
},
|
},
|
||||||
{ enabled: templateHasVariables }
|
{ enabled: disabledReason === null }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!templateHasVariables)
|
if (disabledReason)
|
||||||
return (
|
return (
|
||||||
<Center h="100%">
|
<Center h="100%">
|
||||||
<Text color="gray.500">Add a scenario variable to see output</Text>
|
<Text color="gray.500">{disabledReason}</Text>
|
||||||
</Center>
|
</Center>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,24 @@
|
|||||||
import { Box, Button, HStack, Tooltip } from "@chakra-ui/react";
|
import { Box, Button, HStack, Tooltip, useToast } from "@chakra-ui/react";
|
||||||
import { useMonaco } from "@monaco-editor/react";
|
import { useMonaco } from "@monaco-editor/react";
|
||||||
import { useRef, useEffect, useState, useCallback } from "react";
|
import { useRef, useEffect, useState, useCallback, useMemo } from "react";
|
||||||
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
|
||||||
|
import { PromptVariant } from "./types";
|
||||||
|
import { JSONSerializable } from "~/server/types";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
|
||||||
let isThemeDefined = false;
|
let isThemeDefined = false;
|
||||||
|
|
||||||
export default function VariantConfigEditor(props: {
|
export default function VariantConfigEditor(props: { variant: PromptVariant }) {
|
||||||
savedConfig: string;
|
|
||||||
onSave: (currentConfig: string) => Promise<void>;
|
|
||||||
}) {
|
|
||||||
const monaco = useMonaco();
|
const monaco = useMonaco();
|
||||||
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
|
||||||
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
|
||||||
const [isChanged, setIsChanged] = useState(false);
|
const [isChanged, setIsChanged] = useState(false);
|
||||||
const savedConfigRef = useRef(props.savedConfig);
|
|
||||||
|
const savedConfig = useMemo(
|
||||||
|
() => JSON.stringify(props.variant.config, null, 2),
|
||||||
|
[props.variant.config]
|
||||||
|
);
|
||||||
|
const savedConfigRef = useRef(savedConfig);
|
||||||
|
|
||||||
const modifierKey = useModifierKeyLabel();
|
const modifierKey = useModifierKeyLabel();
|
||||||
|
|
||||||
@@ -23,12 +28,44 @@ export default function VariantConfigEditor(props: {
|
|||||||
setIsChanged(currentConfig !== savedConfigRef.current);
|
setIsChanged(currentConfig !== savedConfigRef.current);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const replaceWithConfig = api.promptVariants.replaceWithConfig.useMutation();
|
||||||
|
const utils = api.useContext();
|
||||||
|
const toast = useToast();
|
||||||
|
|
||||||
const [onSave] = useHandledAsyncCallback(async () => {
|
const [onSave] = useHandledAsyncCallback(async () => {
|
||||||
const currentConfig = editorRef.current?.getValue();
|
const currentConfig = editorRef.current?.getValue();
|
||||||
if (!currentConfig) return;
|
if (!currentConfig) return;
|
||||||
await props.onSave(currentConfig);
|
|
||||||
|
let parsedConfig: JSONSerializable;
|
||||||
|
try {
|
||||||
|
parsedConfig = JSON.parse(currentConfig) as JSONSerializable;
|
||||||
|
} catch (e) {
|
||||||
|
toast({
|
||||||
|
title: "Invalid JSON",
|
||||||
|
description: "Please fix the JSON before saving.",
|
||||||
|
status: "error",
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parsedConfig === null) {
|
||||||
|
toast({
|
||||||
|
title: "Invalid JSON",
|
||||||
|
description: "Please fix the JSON before saving.",
|
||||||
|
status: "error",
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await replaceWithConfig.mutateAsync({
|
||||||
|
id: props.variant.id,
|
||||||
|
config: currentConfig,
|
||||||
|
});
|
||||||
|
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
|
||||||
checkForChanges();
|
checkForChanges();
|
||||||
}, [props.onSave, checkForChanges]);
|
}, [checkForChanges]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (monaco) {
|
if (monaco) {
|
||||||
@@ -47,7 +84,7 @@ export default function VariantConfigEditor(props: {
|
|||||||
const container = document.getElementById(editorId) as HTMLElement;
|
const container = document.getElementById(editorId) as HTMLElement;
|
||||||
|
|
||||||
editorRef.current = monaco.editor.create(container, {
|
editorRef.current = monaco.editor.create(container, {
|
||||||
value: props.savedConfig,
|
value: savedConfig,
|
||||||
language: "json",
|
language: "json",
|
||||||
theme: "customTheme",
|
theme: "customTheme",
|
||||||
lineNumbers: "off",
|
lineNumbers: "off",
|
||||||
@@ -93,16 +130,16 @@ export default function VariantConfigEditor(props: {
|
|||||||
}, [monaco, editorId]);
|
}, [monaco, editorId]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const savedConfigChanged = savedConfigRef.current !== props.savedConfig;
|
const savedConfigChanged = savedConfigRef.current !== savedConfig;
|
||||||
|
|
||||||
savedConfigRef.current = props.savedConfig;
|
savedConfigRef.current = savedConfig;
|
||||||
|
|
||||||
if (savedConfigChanged && editorRef.current?.getValue() !== props.savedConfig) {
|
if (savedConfigChanged && editorRef.current?.getValue() !== savedConfig) {
|
||||||
editorRef.current?.setValue(props.savedConfig);
|
editorRef.current?.setValue(savedConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
checkForChanges();
|
checkForChanges();
|
||||||
}, [props.savedConfig, checkForChanges]);
|
}, [savedConfig, checkForChanges]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box w="100%" pos="relative">
|
<Box w="100%" pos="relative">
|
||||||
@@ -113,7 +150,7 @@ export default function VariantConfigEditor(props: {
|
|||||||
colorScheme="gray"
|
colorScheme="gray"
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
editorRef.current?.setValue(props.savedConfig);
|
editorRef.current?.setValue(savedConfig);
|
||||||
checkForChanges();
|
checkForChanges();
|
||||||
}}
|
}}
|
||||||
borderRadius={0}
|
borderRadius={0}
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ export default function VariantHeader({ variant }: { variant: PromptVariant }) {
|
|||||||
return (
|
return (
|
||||||
<Stack w="100%">
|
<Stack w="100%">
|
||||||
<EditableVariantLabel variant={variant} />
|
<EditableVariantLabel variant={variant} />
|
||||||
<VariantConfigEditor savedConfig={JSON.stringify(variant.config, null, 2)} onSave={onSave} />
|
<VariantConfigEditor variant={variant} />
|
||||||
</Stack>
|
</Stack>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ import ScenarioHeader from "./ScenarioHeader";
|
|||||||
import React, { useState } from "react";
|
import React, { useState } from "react";
|
||||||
import { Box, Grid, GridItem, Heading } from "@chakra-ui/react";
|
import { Box, Grid, GridItem, Heading } from "@chakra-ui/react";
|
||||||
import NewScenarioButton from "./NewScenarioButton";
|
import NewScenarioButton from "./NewScenarioButton";
|
||||||
|
import NewVariantButton from "./NewVariantButton";
|
||||||
|
import EditableVariantLabel from "./EditableVariantLabel";
|
||||||
|
import VariantConfigEditor from "./VariantConfigEditor";
|
||||||
|
|
||||||
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
|
const ScenarioRow = (props: { scenario: Scenario; variants: PromptVariant[] }) => {
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
@@ -54,8 +57,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
<Grid
|
<Grid
|
||||||
p={4}
|
p={4}
|
||||||
display="grid"
|
display="grid"
|
||||||
gridTemplateColumns={`200px repeat(${variants.data.length}, minmax(300px, 1fr))`}
|
gridTemplateColumns={`200px repeat(${variants.data.length}, minmax(300px, 1fr)) 40px`}
|
||||||
overflowX="auto"
|
|
||||||
sx={{
|
sx={{
|
||||||
"> *": {
|
"> *": {
|
||||||
borderColor: "gray.300",
|
borderColor: "gray.300",
|
||||||
@@ -69,14 +71,33 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<GridItem display="flex" alignItems="flex-end">
|
<GridItem display="flex" alignItems="flex-end" rowSpan={2}>
|
||||||
<Heading size="md" fontWeight="bold">
|
<Heading size="md" fontWeight="bold">
|
||||||
Scenario
|
Scenario
|
||||||
</Heading>
|
</Heading>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
|
{variants.data.map((variant) => (
|
||||||
|
<GridItem
|
||||||
|
key={variant.uiId}
|
||||||
|
padding={0}
|
||||||
|
sx={{ position: "sticky", top: 0, backgroundColor: "#fff", zIndex: 1 }}
|
||||||
|
>
|
||||||
|
<EditableVariantLabel variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
))}
|
||||||
|
<GridItem
|
||||||
|
borderBottomWidth={0}
|
||||||
|
rowSpan={scenarios.data.length + 1}
|
||||||
|
padding={0}
|
||||||
|
borderRightWidth={0}
|
||||||
|
sx={{ position: "sticky", top: 0, backgroundColor: "#fff", zIndex: 1 }}
|
||||||
|
>
|
||||||
|
<NewVariantButton />
|
||||||
|
</GridItem>
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId} padding={0}>
|
<GridItem key={variant.uiId} padding={0}>
|
||||||
<VariantHeader key={variant.uiId} variant={variant} />
|
<VariantConfigEditor variant={variant} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
{scenarios.data.map((scenario) => (
|
{scenarios.data.map((scenario) => (
|
||||||
|
|||||||
@@ -16,6 +16,37 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
create: publicProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
experimentId: z.string(),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.mutation(async ({ input }) => {
|
||||||
|
const maxSortIndex =
|
||||||
|
(
|
||||||
|
await prisma.promptVariant.aggregate({
|
||||||
|
where: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
},
|
||||||
|
_max: {
|
||||||
|
sortIndex: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)._max.sortIndex ?? 0;
|
||||||
|
|
||||||
|
const newScenario = await prisma.promptVariant.create({
|
||||||
|
data: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
label: `Prompt Variant ${maxSortIndex + 1}`,
|
||||||
|
sortIndex: maxSortIndex + 1,
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return newScenario;
|
||||||
|
}),
|
||||||
|
|
||||||
update: publicProcedure
|
update: publicProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
|
|||||||
Reference in New Issue
Block a user