Allow user to duplicate prompt (#57)

* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier
This commit is contained in:
arcticfly
2023-07-18 13:49:33 -07:00
committed by GitHub
parent 999a4c08fa
commit fa5b1ab1c5
4 changed files with 220 additions and 129 deletions

View File

@@ -2,11 +2,25 @@ import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here import {
import { BsX } from "react-icons/bs"; Button,
import { RiDraggable } from "react-icons/ri"; HStack,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
GridItem,
Spinner,
} from "@chakra-ui/react"; // Changed here
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants"; import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea"; import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "./styles";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) { export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext(); const utils = api.useContext();
@@ -49,10 +63,29 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
[reorderMutation, props.variant.id], [reorderMutation, props.variant.id],
); );
const [menuOpen, setMenuOpen] = useState(false);
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: props.variant.experimentId,
variantId: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, props.variant.experimentId, props.variant.id]);
return ( return (
<GridItem
padding={0}
sx={{
...stickyHeaderStyle,
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
>
<HStack <HStack
spacing={4} spacing={4}
alignItems="center" alignItems="flex-start"
minH={headerMinHeight} minH={headerMinHeight}
draggable={!isInputHovered} draggable={!isInputHovered}
onDragStart={(e) => { onDragStart={(e) => {
@@ -75,10 +108,11 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
<Icon <Icon
as={RiDraggable} as={RiDraggable}
boxSize={6} boxSize={6}
mt={2}
color="gray.400" color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }} _hover={{ color: "gray.800", cursor: "pointer" }}
/> />
<AutoResizeTextArea // Changed to Input <AutoResizeTextArea
size="sm" size="sm"
value={label} value={label}
onChange={(e) => setLabel(e.target.value)} onChange={(e) => setLabel(e.target.value)}
@@ -95,13 +129,43 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
onMouseEnter={() => setIsInputHovered(true)} onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)} onMouseLeave={() => setIsInputHovered(false)}
/> />
{props.canHide && (
<Tooltip label="Remove Variant" hasArrow> <Menu
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}> z-index="dropdown"
<Icon as={BsX} boxSize={6} /> onOpen={() => setMenuOpen(true)}
onClose={() => setMenuOpen(false)}
>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button> </Button>
</Tooltip> </MenuButton>
)} )}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}>Change Model</MenuItem>
{props.canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
</HStack> </HStack>
</GridItem>
); );
} }

View File

@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} /> <ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => ( {variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}> <VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
</GridItem>
))} ))}
<GridItem <GridItem
rowSpan={scenarios.data.length + headerRows} rowSpan={scenarios.data.length + headerRows}

View File

@@ -9,6 +9,8 @@ import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error"; import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -135,10 +137,19 @@ export const promptVariantsRouter = createTRPCRouter({
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
variantId: z.string().optional(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const lastVariant = await prisma.promptVariant.findFirst({ let originalVariant: PromptVariant | null = null;
if (input.variantId) {
originalVariant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
} else {
originalVariant = await prisma.promptVariant.findFirst({
where: { where: {
experimentId: input.experimentId, experimentId: input.experimentId,
visible: true, visible: true,
@@ -147,6 +158,7 @@ export const promptVariantsRouter = createTRPCRouter({
sortIndex: "desc", sortIndex: "desc",
}, },
}); });
}
const largestSortIndex = const largestSortIndex =
( (
@@ -160,13 +172,18 @@ export const promptVariantsRouter = createTRPCRouter({
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
const newVariantLabel =
input.variantId && originalVariant
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const createNewVariantAction = prisma.promptVariant.create({ const createNewVariantAction = prisma.promptVariant.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`, label: newVariantLabel,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1, sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn: constructFn:
lastVariant?.constructFn ?? originalVariant?.constructFn ??
dedent` dedent`
prompt = { prompt = {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
@@ -177,7 +194,7 @@ export const promptVariantsRouter = createTRPCRouter({
} }
] ]
}`, }`,
model: lastVariant?.model ?? "gpt-3.5-turbo", model: originalVariant?.model ?? "gpt-3.5-turbo",
}, },
}); });
@@ -186,6 +203,11 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId), recordExperimentUpdated(input.experimentId),
]); ]);
if (originalVariant) {
// Insert new variant to right of original variant
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
}
const scenarios = await prisma.testScenario.findMany({ const scenarios = await prisma.testScenario.findMany({
where: { where: {
experimentId: input.experimentId, experimentId: input.experimentId,
@@ -338,64 +360,6 @@ export const promptVariantsRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const dragged = await prisma.promptVariant.findUnique({ await reorderPromptVariants(input.draggedId, input.droppedId);
where: {
id: input.draggedId,
},
});
const dropped = await prisma.promptVariant.findUnique({
where: {
id: input.droppedId,
},
});
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
throw new Error(
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: dragged.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the dragged item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
// Find the index of the dragged item and the dropped item
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
// Determine the new index for the dragged item
let newIndex;
if (dragIndex < dropIndex) {
newIndex = dropIndex + 1; // Insert after the dropped item
} else {
newIndex = dropIndex; // Insert before the dropped item
}
// Insert the dragged item at the new position
orderedItems.splice(newIndex, 0, dragged);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
}), }),
}); });

View File

@@ -0,0 +1,65 @@
import { prisma } from "~/server/db";
export const reorderPromptVariants = async (
movedId: string,
stationaryTargetId: string,
alwaysInsertRight?: boolean,
) => {
const moved = await prisma.promptVariant.findUnique({
where: {
id: movedId,
},
});
const target = await prisma.promptVariant.findUnique({
where: {
id: stationaryTargetId,
},
});
if (!moved || !target || moved.experimentId !== target.experimentId) {
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: moved.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the moved item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
// Find the index of the moved item and the target item
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
// Determine the new index for the moved item
let newIndex;
if (movedIndex < targetIndex || alwaysInsertRight) {
newIndex = targetIndex + 1; // Insert after the target item
} else {
newIndex = targetIndex; // Insert before the target item
}
// Insert the moved item at the new position
orderedItems.splice(newIndex, 0, moved);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
};