From fa5b1ab1c5268294fbd3afcbe354187867518347 Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Tue, 18 Jul 2023 13:49:33 -0700 Subject: [PATCH] Allow user to duplicate prompt (#57) * Add dropdown header for model switching * Allow variant duplication * Fix prettier --- src/components/OutputsTable/VariantHeader.tsx | 172 ++++++++++++------ src/components/OutputsTable/index.tsx | 4 +- .../api/routers/promptVariants.router.ts | 108 ++++------- src/server/utils/reorderPromptVariants.ts | 65 +++++++ 4 files changed, 220 insertions(+), 129 deletions(-) create mode 100644 src/server/utils/reorderPromptVariants.ts diff --git a/src/components/OutputsTable/VariantHeader.tsx b/src/components/OutputsTable/VariantHeader.tsx index cdfc68a..25b26c6 100644 --- a/src/components/OutputsTable/VariantHeader.tsx +++ b/src/components/OutputsTable/VariantHeader.tsx @@ -2,11 +2,25 @@ import { useState, type DragEvent } from "react"; import { type PromptVariant } from "./types"; import { api } from "~/utils/api"; import { useHandledAsyncCallback } from "~/utils/hooks"; -import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here -import { BsX } from "react-icons/bs"; -import { RiDraggable } from "react-icons/ri"; +import { + Button, + HStack, + Icon, + Menu, + MenuButton, + MenuItem, + MenuList, + MenuDivider, + Text, + GridItem, + Spinner, +} from "@chakra-ui/react"; // Changed here +import { BsFillTrashFill, BsGear } from "react-icons/bs"; +import { FaRegClone } from "react-icons/fa"; +import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri"; import { cellPadding, headerMinHeight } from "../constants"; import AutoResizeTextArea from "../AutoResizeTextArea"; +import { stickyHeaderStyle } from "./styles"; export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) { const utils = api.useContext(); @@ -49,59 +63,109 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide: [reorderMutation, props.variant.id], ); + const [menuOpen, setMenuOpen] = useState(false); + const duplicateMutation = api.promptVariants.create.useMutation(); + + const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => { + await duplicateMutation.mutateAsync({ + experimentId: props.variant.experimentId, + variantId: props.variant.id, + }); + await utils.promptVariants.list.invalidate(); + }, [duplicateMutation, props.variant.experimentId, props.variant.id]); + return ( - { - e.dataTransfer.setData("text/plain", props.variant.id); - e.currentTarget.style.opacity = "0.4"; + { - e.currentTarget.style.opacity = "1"; - }} - onDragOver={(e) => { - e.preventDefault(); - setIsDragTarget(true); - }} - onDragLeave={() => { - setIsDragTarget(false); - }} - onDrop={onReorder} - backgroundColor={isDragTarget ? "gray.100" : "transparent"} + borderTopWidth={1} > - - setLabel(e.target.value)} - onBlur={onSaveLabel} - placeholder="Variant Name" - borderWidth={1} - borderColor="transparent" - fontWeight="bold" - fontSize={16} - _hover={{ borderColor: "gray.300" }} - _focus={{ borderColor: "blue.500", outline: "none" }} - flex={1} - px={cellPadding.x} - onMouseEnter={() => setIsInputHovered(true)} - onMouseLeave={() => setIsInputHovered(false)} - /> - {props.canHide && ( - - - - )} - + { + e.dataTransfer.setData("text/plain", props.variant.id); + e.currentTarget.style.opacity = "0.4"; + }} + onDragEnd={(e) => { + e.currentTarget.style.opacity = "1"; + }} + onDragOver={(e) => { + e.preventDefault(); + setIsDragTarget(true); + }} + onDragLeave={() => { + setIsDragTarget(false); + }} + onDrop={onReorder} + backgroundColor={isDragTarget ? "gray.100" : "transparent"} + > + + setLabel(e.target.value)} + onBlur={onSaveLabel} + placeholder="Variant Name" + borderWidth={1} + borderColor="transparent" + fontWeight="bold" + fontSize={16} + _hover={{ borderColor: "gray.300" }} + _focus={{ borderColor: "blue.500", outline: "none" }} + flex={1} + px={cellPadding.x} + onMouseEnter={() => setIsInputHovered(true)} + onMouseLeave={() => setIsInputHovered(false)} + /> + + setMenuOpen(true)} + onClose={() => setMenuOpen(false)} + > + {duplicationInProgress ? ( + + ) : ( + + + + )} + + + } onClick={duplicateVariant}> + Duplicate + + }>Change Model + {props.canHide && ( + <> + + } + color="red.600" + _hover={{ backgroundColor: "red.50" }} + > + Hide + + + )} + + + + ); } diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx index 782fe4e..d35e327 100644 --- a/src/components/OutputsTable/index.tsx +++ b/src/components/OutputsTable/index.tsx @@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string | {variants.data.map((variant) => ( - - 1} /> - + 1} /> ))} { @@ -135,18 +137,28 @@ export const promptVariantsRouter = createTRPCRouter({ .input( z.object({ experimentId: z.string(), + variantId: z.string().optional(), }), ) .mutation(async ({ input }) => { - const lastVariant = await prisma.promptVariant.findFirst({ - where: { - experimentId: input.experimentId, - visible: true, - }, - orderBy: { - sortIndex: "desc", - }, - }); + let originalVariant: PromptVariant | null = null; + if (input.variantId) { + originalVariant = await prisma.promptVariant.findUnique({ + where: { + id: input.variantId, + }, + }); + } else { + originalVariant = await prisma.promptVariant.findFirst({ + where: { + experimentId: input.experimentId, + visible: true, + }, + orderBy: { + sortIndex: "desc", + }, + }); + } const largestSortIndex = ( @@ -160,13 +172,18 @@ export const promptVariantsRouter = createTRPCRouter({ }) )._max?.sortIndex ?? 0; + const newVariantLabel = + input.variantId && originalVariant + ? `${originalVariant?.label} Copy` + : `Prompt Variant ${largestSortIndex + 2}`; + const createNewVariantAction = prisma.promptVariant.create({ data: { experimentId: input.experimentId, - label: `Prompt Variant ${largestSortIndex + 2}`, - sortIndex: (lastVariant?.sortIndex ?? 0) + 1, + label: newVariantLabel, + sortIndex: (originalVariant?.sortIndex ?? 0) + 1, constructFn: - lastVariant?.constructFn ?? + originalVariant?.constructFn ?? dedent` prompt = { model: "gpt-3.5-turbo", @@ -177,7 +194,7 @@ export const promptVariantsRouter = createTRPCRouter({ } ] }`, - model: lastVariant?.model ?? "gpt-3.5-turbo", + model: originalVariant?.model ?? "gpt-3.5-turbo", }, }); @@ -186,6 +203,11 @@ export const promptVariantsRouter = createTRPCRouter({ recordExperimentUpdated(input.experimentId), ]); + if (originalVariant) { + // Insert new variant to right of original variant + await reorderPromptVariants(newVariant.id, originalVariant.id, true); + } + const scenarios = await prisma.testScenario.findMany({ where: { experimentId: input.experimentId, @@ -338,64 +360,6 @@ export const promptVariantsRouter = createTRPCRouter({ }), ) .mutation(async ({ input }) => { - const dragged = await prisma.promptVariant.findUnique({ - where: { - id: input.draggedId, - }, - }); - - const dropped = await prisma.promptVariant.findUnique({ - where: { - id: input.droppedId, - }, - }); - - if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) { - throw new Error( - `Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`, - ); - } - - const visibleItems = await prisma.promptVariant.findMany({ - where: { - experimentId: dragged.experimentId, - visible: true, - }, - orderBy: { - sortIndex: "asc", - }, - }); - - // Remove the dragged item from its current position - const orderedItems = visibleItems.filter((item) => item.id !== dragged.id); - - // Find the index of the dragged item and the dropped item - const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id); - const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id); - - // Determine the new index for the dragged item - let newIndex; - if (dragIndex < dropIndex) { - newIndex = dropIndex + 1; // Insert after the dropped item - } else { - newIndex = dropIndex; // Insert before the dropped item - } - - // Insert the dragged item at the new position - orderedItems.splice(newIndex, 0, dragged); - - // Now, we need to update all the items with their new sortIndex - await prisma.$transaction( - orderedItems.map((item, index) => { - return prisma.promptVariant.update({ - where: { - id: item.id, - }, - data: { - sortIndex: index, - }, - }); - }), - ); + await reorderPromptVariants(input.draggedId, input.droppedId); }), }); diff --git a/src/server/utils/reorderPromptVariants.ts b/src/server/utils/reorderPromptVariants.ts new file mode 100644 index 0000000..9a1ba87 --- /dev/null +++ b/src/server/utils/reorderPromptVariants.ts @@ -0,0 +1,65 @@ +import { prisma } from "~/server/db"; + +export const reorderPromptVariants = async ( + movedId: string, + stationaryTargetId: string, + alwaysInsertRight?: boolean, +) => { + const moved = await prisma.promptVariant.findUnique({ + where: { + id: movedId, + }, + }); + + const target = await prisma.promptVariant.findUnique({ + where: { + id: stationaryTargetId, + }, + }); + + if (!moved || !target || moved.experimentId !== target.experimentId) { + throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`); + } + + const visibleItems = await prisma.promptVariant.findMany({ + where: { + experimentId: moved.experimentId, + visible: true, + }, + orderBy: { + sortIndex: "asc", + }, + }); + + // Remove the moved item from its current position + const orderedItems = visibleItems.filter((item) => item.id !== moved.id); + + // Find the index of the moved item and the target item + const movedIndex = visibleItems.findIndex((item) => item.id === moved.id); + const targetIndex = visibleItems.findIndex((item) => item.id === target.id); + + // Determine the new index for the moved item + let newIndex; + if (movedIndex < targetIndex || alwaysInsertRight) { + newIndex = targetIndex + 1; // Insert after the target item + } else { + newIndex = targetIndex; // Insert before the target item + } + + // Insert the moved item at the new position + orderedItems.splice(newIndex, 0, moved); + + // Now, we need to update all the items with their new sortIndex + await prisma.$transaction( + orderedItems.map((item, index) => { + return prisma.promptVariant.update({ + where: { + id: item.id, + }, + data: { + sortIndex: index, + }, + }); + }), + ); +};