Allow user to duplicate prompt (#57)

* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier
This commit is contained in:
arcticfly
2023-07-18 13:49:33 -07:00
committed by GitHub
parent 999a4c08fa
commit fa5b1ab1c5
4 changed files with 220 additions and 129 deletions

View File

@@ -2,11 +2,25 @@ import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here import {
import { BsX } from "react-icons/bs"; Button,
import { RiDraggable } from "react-icons/ri"; HStack,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
GridItem,
Spinner,
} from "@chakra-ui/react"; // Changed here
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants"; import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea"; import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "./styles";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) { export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext(); const utils = api.useContext();
@@ -49,59 +63,109 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
[reorderMutation, props.variant.id], [reorderMutation, props.variant.id],
); );
const [menuOpen, setMenuOpen] = useState(false);
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: props.variant.experimentId,
variantId: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, props.variant.experimentId, props.variant.id]);
return ( return (
<HStack <GridItem
spacing={4} padding={0}
alignItems="center" sx={{
minH={headerMinHeight} ...stickyHeaderStyle,
draggable={!isInputHovered} zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}} }}
onDragEnd={(e) => { borderTopWidth={1}
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
> >
<Icon <HStack
as={RiDraggable} spacing={4}
boxSize={6} alignItems="flex-start"
color="gray.400" minH={headerMinHeight}
_hover={{ color: "gray.800", cursor: "pointer" }} draggable={!isInputHovered}
/> onDragStart={(e) => {
<AutoResizeTextArea // Changed to Input e.dataTransfer.setData("text/plain", props.variant.id);
size="sm" e.currentTarget.style.opacity = "0.4";
value={label} }}
onChange={(e) => setLabel(e.target.value)} onDragEnd={(e) => {
onBlur={onSaveLabel} e.currentTarget.style.opacity = "1";
placeholder="Variant Name" }}
borderWidth={1} onDragOver={(e) => {
borderColor="transparent" e.preventDefault();
fontWeight="bold" setIsDragTarget(true);
fontSize={16} }}
_hover={{ borderColor: "gray.300" }} onDragLeave={() => {
_focus={{ borderColor: "blue.500", outline: "none" }} setIsDragTarget(false);
flex={1} }}
px={cellPadding.x} onDrop={onReorder}
onMouseEnter={() => setIsInputHovered(true)} backgroundColor={isDragTarget ? "gray.100" : "transparent"}
onMouseLeave={() => setIsInputHovered(false)} >
/> <Icon
{props.canHide && ( as={RiDraggable}
<Tooltip label="Remove Variant" hasArrow> boxSize={6}
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}> mt={2}
<Icon as={BsX} boxSize={6} /> color="gray.400"
</Button> _hover={{ color: "gray.800", cursor: "pointer" }}
</Tooltip> />
)} <AutoResizeTextArea
</HStack> size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<Menu
z-index="dropdown"
onOpen={() => setMenuOpen(true)}
onClose={() => setMenuOpen(false)}
>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}>Change Model</MenuItem>
{props.canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
</HStack>
</GridItem>
); );
} }

View File

@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} /> <ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => ( {variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}> <VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
</GridItem>
))} ))}
<GridItem <GridItem
rowSpan={scenarios.data.length + headerRows} rowSpan={scenarios.data.length + headerRows}

View File

@@ -9,6 +9,8 @@ import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error"; import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -135,18 +137,28 @@ export const promptVariantsRouter = createTRPCRouter({
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
variantId: z.string().optional(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const lastVariant = await prisma.promptVariant.findFirst({ let originalVariant: PromptVariant | null = null;
where: { if (input.variantId) {
experimentId: input.experimentId, originalVariant = await prisma.promptVariant.findUnique({
visible: true, where: {
}, id: input.variantId,
orderBy: { },
sortIndex: "desc", });
}, } else {
}); originalVariant = await prisma.promptVariant.findFirst({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "desc",
},
});
}
const largestSortIndex = const largestSortIndex =
( (
@@ -160,13 +172,18 @@ export const promptVariantsRouter = createTRPCRouter({
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
const newVariantLabel =
input.variantId && originalVariant
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const createNewVariantAction = prisma.promptVariant.create({ const createNewVariantAction = prisma.promptVariant.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`, label: newVariantLabel,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1, sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn: constructFn:
lastVariant?.constructFn ?? originalVariant?.constructFn ??
dedent` dedent`
prompt = { prompt = {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
@@ -177,7 +194,7 @@ export const promptVariantsRouter = createTRPCRouter({
} }
] ]
}`, }`,
model: lastVariant?.model ?? "gpt-3.5-turbo", model: originalVariant?.model ?? "gpt-3.5-turbo",
}, },
}); });
@@ -186,6 +203,11 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId), recordExperimentUpdated(input.experimentId),
]); ]);
if (originalVariant) {
// Insert new variant to right of original variant
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
}
const scenarios = await prisma.testScenario.findMany({ const scenarios = await prisma.testScenario.findMany({
where: { where: {
experimentId: input.experimentId, experimentId: input.experimentId,
@@ -338,64 +360,6 @@ export const promptVariantsRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const dragged = await prisma.promptVariant.findUnique({ await reorderPromptVariants(input.draggedId, input.droppedId);
where: {
id: input.draggedId,
},
});
const dropped = await prisma.promptVariant.findUnique({
where: {
id: input.droppedId,
},
});
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
throw new Error(
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: dragged.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the dragged item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
// Find the index of the dragged item and the dropped item
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
// Determine the new index for the dragged item
let newIndex;
if (dragIndex < dropIndex) {
newIndex = dropIndex + 1; // Insert after the dropped item
} else {
newIndex = dropIndex; // Insert before the dropped item
}
// Insert the dragged item at the new position
orderedItems.splice(newIndex, 0, dragged);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
}), }),
}); });

View File

@@ -0,0 +1,65 @@
import { prisma } from "~/server/db";
export const reorderPromptVariants = async (
movedId: string,
stationaryTargetId: string,
alwaysInsertRight?: boolean,
) => {
const moved = await prisma.promptVariant.findUnique({
where: {
id: movedId,
},
});
const target = await prisma.promptVariant.findUnique({
where: {
id: stationaryTargetId,
},
});
if (!moved || !target || moved.experimentId !== target.experimentId) {
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: moved.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the moved item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
// Find the index of the moved item and the target item
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
// Determine the new index for the moved item
let newIndex;
if (movedIndex < targetIndex || alwaysInsertRight) {
newIndex = targetIndex + 1; // Insert after the target item
} else {
newIndex = targetIndex; // Insert before the target item
}
// Insert the moved item at the new position
orderedItems.splice(newIndex, 0, moved);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
};