Allow user to duplicate prompt (#57)
* Add dropdown header for model switching * Allow variant duplication * Fix prettier
This commit is contained in:
@@ -2,11 +2,25 @@ import { useState, type DragEvent } from "react";
|
|||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useHandledAsyncCallback } from "~/utils/hooks";
|
import { useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
|
import {
|
||||||
import { BsX } from "react-icons/bs";
|
Button,
|
||||||
import { RiDraggable } from "react-icons/ri";
|
HStack,
|
||||||
|
Icon,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
MenuDivider,
|
||||||
|
Text,
|
||||||
|
GridItem,
|
||||||
|
Spinner,
|
||||||
|
} from "@chakra-ui/react"; // Changed here
|
||||||
|
import { BsFillTrashFill, BsGear } from "react-icons/bs";
|
||||||
|
import { FaRegClone } from "react-icons/fa";
|
||||||
|
import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
|
||||||
import { cellPadding, headerMinHeight } from "../constants";
|
import { cellPadding, headerMinHeight } from "../constants";
|
||||||
import AutoResizeTextArea from "../AutoResizeTextArea";
|
import AutoResizeTextArea from "../AutoResizeTextArea";
|
||||||
|
import { stickyHeaderStyle } from "./styles";
|
||||||
|
|
||||||
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
@@ -49,59 +63,109 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
|
|||||||
[reorderMutation, props.variant.id],
|
[reorderMutation, props.variant.id],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [menuOpen, setMenuOpen] = useState(false);
|
||||||
|
const duplicateMutation = api.promptVariants.create.useMutation();
|
||||||
|
|
||||||
|
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
|
||||||
|
await duplicateMutation.mutateAsync({
|
||||||
|
experimentId: props.variant.experimentId,
|
||||||
|
variantId: props.variant.id,
|
||||||
|
});
|
||||||
|
await utils.promptVariants.list.invalidate();
|
||||||
|
}, [duplicateMutation, props.variant.experimentId, props.variant.id]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack
|
<GridItem
|
||||||
spacing={4}
|
padding={0}
|
||||||
alignItems="center"
|
sx={{
|
||||||
minH={headerMinHeight}
|
...stickyHeaderStyle,
|
||||||
draggable={!isInputHovered}
|
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
|
||||||
onDragStart={(e) => {
|
|
||||||
e.dataTransfer.setData("text/plain", props.variant.id);
|
|
||||||
e.currentTarget.style.opacity = "0.4";
|
|
||||||
}}
|
}}
|
||||||
onDragEnd={(e) => {
|
borderTopWidth={1}
|
||||||
e.currentTarget.style.opacity = "1";
|
|
||||||
}}
|
|
||||||
onDragOver={(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setIsDragTarget(true);
|
|
||||||
}}
|
|
||||||
onDragLeave={() => {
|
|
||||||
setIsDragTarget(false);
|
|
||||||
}}
|
|
||||||
onDrop={onReorder}
|
|
||||||
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
|
||||||
>
|
>
|
||||||
<Icon
|
<HStack
|
||||||
as={RiDraggable}
|
spacing={4}
|
||||||
boxSize={6}
|
alignItems="flex-start"
|
||||||
color="gray.400"
|
minH={headerMinHeight}
|
||||||
_hover={{ color: "gray.800", cursor: "pointer" }}
|
draggable={!isInputHovered}
|
||||||
/>
|
onDragStart={(e) => {
|
||||||
<AutoResizeTextArea // Changed to Input
|
e.dataTransfer.setData("text/plain", props.variant.id);
|
||||||
size="sm"
|
e.currentTarget.style.opacity = "0.4";
|
||||||
value={label}
|
}}
|
||||||
onChange={(e) => setLabel(e.target.value)}
|
onDragEnd={(e) => {
|
||||||
onBlur={onSaveLabel}
|
e.currentTarget.style.opacity = "1";
|
||||||
placeholder="Variant Name"
|
}}
|
||||||
borderWidth={1}
|
onDragOver={(e) => {
|
||||||
borderColor="transparent"
|
e.preventDefault();
|
||||||
fontWeight="bold"
|
setIsDragTarget(true);
|
||||||
fontSize={16}
|
}}
|
||||||
_hover={{ borderColor: "gray.300" }}
|
onDragLeave={() => {
|
||||||
_focus={{ borderColor: "blue.500", outline: "none" }}
|
setIsDragTarget(false);
|
||||||
flex={1}
|
}}
|
||||||
px={cellPadding.x}
|
onDrop={onReorder}
|
||||||
onMouseEnter={() => setIsInputHovered(true)}
|
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
|
||||||
onMouseLeave={() => setIsInputHovered(false)}
|
>
|
||||||
/>
|
<Icon
|
||||||
{props.canHide && (
|
as={RiDraggable}
|
||||||
<Tooltip label="Remove Variant" hasArrow>
|
boxSize={6}
|
||||||
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
|
mt={2}
|
||||||
<Icon as={BsX} boxSize={6} />
|
color="gray.400"
|
||||||
</Button>
|
_hover={{ color: "gray.800", cursor: "pointer" }}
|
||||||
</Tooltip>
|
/>
|
||||||
)}
|
<AutoResizeTextArea
|
||||||
</HStack>
|
size="sm"
|
||||||
|
value={label}
|
||||||
|
onChange={(e) => setLabel(e.target.value)}
|
||||||
|
onBlur={onSaveLabel}
|
||||||
|
placeholder="Variant Name"
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor="transparent"
|
||||||
|
fontWeight="bold"
|
||||||
|
fontSize={16}
|
||||||
|
_hover={{ borderColor: "gray.300" }}
|
||||||
|
_focus={{ borderColor: "blue.500", outline: "none" }}
|
||||||
|
flex={1}
|
||||||
|
px={cellPadding.x}
|
||||||
|
onMouseEnter={() => setIsInputHovered(true)}
|
||||||
|
onMouseLeave={() => setIsInputHovered(false)}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Menu
|
||||||
|
z-index="dropdown"
|
||||||
|
onOpen={() => setMenuOpen(true)}
|
||||||
|
onClose={() => setMenuOpen(false)}
|
||||||
|
>
|
||||||
|
{duplicationInProgress ? (
|
||||||
|
<Spinner boxSize={4} mx={3} my={3} />
|
||||||
|
) : (
|
||||||
|
<MenuButton>
|
||||||
|
<Button variant="ghost">
|
||||||
|
<Icon as={BsGear} />
|
||||||
|
</Button>
|
||||||
|
</MenuButton>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<MenuList mt={-3} fontSize="md">
|
||||||
|
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
|
||||||
|
Duplicate
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}>Change Model</MenuItem>
|
||||||
|
{props.canHide && (
|
||||||
|
<>
|
||||||
|
<MenuDivider />
|
||||||
|
<MenuItem
|
||||||
|
onClick={onHide}
|
||||||
|
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
|
||||||
|
color="red.600"
|
||||||
|
_hover={{ backgroundColor: "red.50" }}
|
||||||
|
>
|
||||||
|
<Text>Hide</Text>
|
||||||
|
</MenuItem>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
</HStack>
|
||||||
|
</GridItem>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
|
<VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
|
||||||
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
|
|
||||||
</GridItem>
|
|
||||||
))}
|
))}
|
||||||
<GridItem
|
<GridItem
|
||||||
rowSpan={scenarios.data.length + headerRows}
|
rowSpan={scenarios.data.length + headerRows}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import { constructPrompt } from "~/server/utils/constructPrompt";
|
|||||||
import userError from "~/server/utils/error";
|
import userError from "~/server/utils/error";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||||
|
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||||
|
import { type PromptVariant } from "@prisma/client";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||||
@@ -135,18 +137,28 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
|
variantId: z.string().optional(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input }) => {
|
||||||
const lastVariant = await prisma.promptVariant.findFirst({
|
let originalVariant: PromptVariant | null = null;
|
||||||
where: {
|
if (input.variantId) {
|
||||||
experimentId: input.experimentId,
|
originalVariant = await prisma.promptVariant.findUnique({
|
||||||
visible: true,
|
where: {
|
||||||
},
|
id: input.variantId,
|
||||||
orderBy: {
|
},
|
||||||
sortIndex: "desc",
|
});
|
||||||
},
|
} else {
|
||||||
});
|
originalVariant = await prisma.promptVariant.findFirst({
|
||||||
|
where: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
orderBy: {
|
||||||
|
sortIndex: "desc",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const largestSortIndex =
|
const largestSortIndex =
|
||||||
(
|
(
|
||||||
@@ -160,13 +172,18 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
})
|
})
|
||||||
)._max?.sortIndex ?? 0;
|
)._max?.sortIndex ?? 0;
|
||||||
|
|
||||||
|
const newVariantLabel =
|
||||||
|
input.variantId && originalVariant
|
||||||
|
? `${originalVariant?.label} Copy`
|
||||||
|
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||||
|
|
||||||
const createNewVariantAction = prisma.promptVariant.create({
|
const createNewVariantAction = prisma.promptVariant.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
label: `Prompt Variant ${largestSortIndex + 2}`,
|
label: newVariantLabel,
|
||||||
sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
|
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||||
constructFn:
|
constructFn:
|
||||||
lastVariant?.constructFn ??
|
originalVariant?.constructFn ??
|
||||||
dedent`
|
dedent`
|
||||||
prompt = {
|
prompt = {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
@@ -177,7 +194,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`,
|
}`,
|
||||||
model: lastVariant?.model ?? "gpt-3.5-turbo",
|
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -186,6 +203,11 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
recordExperimentUpdated(input.experimentId),
|
recordExperimentUpdated(input.experimentId),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
if (originalVariant) {
|
||||||
|
// Insert new variant to right of original variant
|
||||||
|
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||||
|
}
|
||||||
|
|
||||||
const scenarios = await prisma.testScenario.findMany({
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
where: {
|
where: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
@@ -338,64 +360,6 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input }) => {
|
||||||
const dragged = await prisma.promptVariant.findUnique({
|
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||||
where: {
|
|
||||||
id: input.draggedId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const dropped = await prisma.promptVariant.findUnique({
|
|
||||||
where: {
|
|
||||||
id: input.droppedId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
|
|
||||||
throw new Error(
|
|
||||||
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const visibleItems = await prisma.promptVariant.findMany({
|
|
||||||
where: {
|
|
||||||
experimentId: dragged.experimentId,
|
|
||||||
visible: true,
|
|
||||||
},
|
|
||||||
orderBy: {
|
|
||||||
sortIndex: "asc",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remove the dragged item from its current position
|
|
||||||
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
|
|
||||||
|
|
||||||
// Find the index of the dragged item and the dropped item
|
|
||||||
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
|
|
||||||
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
|
|
||||||
|
|
||||||
// Determine the new index for the dragged item
|
|
||||||
let newIndex;
|
|
||||||
if (dragIndex < dropIndex) {
|
|
||||||
newIndex = dropIndex + 1; // Insert after the dropped item
|
|
||||||
} else {
|
|
||||||
newIndex = dropIndex; // Insert before the dropped item
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert the dragged item at the new position
|
|
||||||
orderedItems.splice(newIndex, 0, dragged);
|
|
||||||
|
|
||||||
// Now, we need to update all the items with their new sortIndex
|
|
||||||
await prisma.$transaction(
|
|
||||||
orderedItems.map((item, index) => {
|
|
||||||
return prisma.promptVariant.update({
|
|
||||||
where: {
|
|
||||||
id: item.id,
|
|
||||||
},
|
|
||||||
data: {
|
|
||||||
sortIndex: index,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
65
src/server/utils/reorderPromptVariants.ts
Normal file
65
src/server/utils/reorderPromptVariants.ts
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import { prisma } from "~/server/db";
|
||||||
|
|
||||||
|
export const reorderPromptVariants = async (
|
||||||
|
movedId: string,
|
||||||
|
stationaryTargetId: string,
|
||||||
|
alwaysInsertRight?: boolean,
|
||||||
|
) => {
|
||||||
|
const moved = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: movedId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const target = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: stationaryTargetId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!moved || !target || moved.experimentId !== target.experimentId) {
|
||||||
|
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const visibleItems = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: moved.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
orderBy: {
|
||||||
|
sortIndex: "asc",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Remove the moved item from its current position
|
||||||
|
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
|
||||||
|
|
||||||
|
// Find the index of the moved item and the target item
|
||||||
|
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
|
||||||
|
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
|
||||||
|
|
||||||
|
// Determine the new index for the moved item
|
||||||
|
let newIndex;
|
||||||
|
if (movedIndex < targetIndex || alwaysInsertRight) {
|
||||||
|
newIndex = targetIndex + 1; // Insert after the target item
|
||||||
|
} else {
|
||||||
|
newIndex = targetIndex; // Insert before the target item
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the moved item at the new position
|
||||||
|
orderedItems.splice(newIndex, 0, moved);
|
||||||
|
|
||||||
|
// Now, we need to update all the items with their new sortIndex
|
||||||
|
await prisma.$transaction(
|
||||||
|
orderedItems.map((item, index) => {
|
||||||
|
return prisma.promptVariant.update({
|
||||||
|
where: {
|
||||||
|
id: item.id,
|
||||||
|
},
|
||||||
|
data: {
|
||||||
|
sortIndex: index,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user