From fa5b1ab1c5268294fbd3afcbe354187867518347 Mon Sep 17 00:00:00 2001
From: arcticfly <41524992+arcticfly@users.noreply.github.com>
Date: Tue, 18 Jul 2023 13:49:33 -0700
Subject: [PATCH] Allow user to duplicate prompt (#57)
* Add dropdown header for model switching
* Allow variant duplication
* Fix prettier
---
src/components/OutputsTable/VariantHeader.tsx | 172 ++++++++++++------
src/components/OutputsTable/index.tsx | 4 +-
.../api/routers/promptVariants.router.ts | 108 ++++-------
src/server/utils/reorderPromptVariants.ts | 65 +++++++
4 files changed, 220 insertions(+), 129 deletions(-)
create mode 100644 src/server/utils/reorderPromptVariants.ts
diff --git a/src/components/OutputsTable/VariantHeader.tsx b/src/components/OutputsTable/VariantHeader.tsx
index cdfc68a..25b26c6 100644
--- a/src/components/OutputsTable/VariantHeader.tsx
+++ b/src/components/OutputsTable/VariantHeader.tsx
@@ -2,11 +2,25 @@ import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
-import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
-import { BsX } from "react-icons/bs";
-import { RiDraggable } from "react-icons/ri";
+import {
+ Button,
+ HStack,
+ Icon,
+ Menu,
+ MenuButton,
+ MenuItem,
+ MenuList,
+ MenuDivider,
+ Text,
+ GridItem,
+ Spinner,
+} from "@chakra-ui/react"; // Changed here
+import { BsFillTrashFill, BsGear } from "react-icons/bs";
+import { FaRegClone } from "react-icons/fa";
+import { RiDraggable, RiExchangeFundsFill } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
+import { stickyHeaderStyle } from "./styles";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext();
@@ -49,59 +63,109 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
[reorderMutation, props.variant.id],
);
+ const [menuOpen, setMenuOpen] = useState(false);
+ const duplicateMutation = api.promptVariants.create.useMutation();
+
+ const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
+ await duplicateMutation.mutateAsync({
+ experimentId: props.variant.experimentId,
+ variantId: props.variant.id,
+ });
+ await utils.promptVariants.list.invalidate();
+ }, [duplicateMutation, props.variant.experimentId, props.variant.id]);
+
return (
- {
- e.dataTransfer.setData("text/plain", props.variant.id);
- e.currentTarget.style.opacity = "0.4";
+ {
- e.currentTarget.style.opacity = "1";
- }}
- onDragOver={(e) => {
- e.preventDefault();
- setIsDragTarget(true);
- }}
- onDragLeave={() => {
- setIsDragTarget(false);
- }}
- onDrop={onReorder}
- backgroundColor={isDragTarget ? "gray.100" : "transparent"}
+ borderTopWidth={1}
>
-
- setLabel(e.target.value)}
- onBlur={onSaveLabel}
- placeholder="Variant Name"
- borderWidth={1}
- borderColor="transparent"
- fontWeight="bold"
- fontSize={16}
- _hover={{ borderColor: "gray.300" }}
- _focus={{ borderColor: "blue.500", outline: "none" }}
- flex={1}
- px={cellPadding.x}
- onMouseEnter={() => setIsInputHovered(true)}
- onMouseLeave={() => setIsInputHovered(false)}
- />
- {props.canHide && (
-
-
-
- )}
-
+ {
+ e.dataTransfer.setData("text/plain", props.variant.id);
+ e.currentTarget.style.opacity = "0.4";
+ }}
+ onDragEnd={(e) => {
+ e.currentTarget.style.opacity = "1";
+ }}
+ onDragOver={(e) => {
+ e.preventDefault();
+ setIsDragTarget(true);
+ }}
+ onDragLeave={() => {
+ setIsDragTarget(false);
+ }}
+ onDrop={onReorder}
+ backgroundColor={isDragTarget ? "gray.100" : "transparent"}
+ >
+
+ setLabel(e.target.value)}
+ onBlur={onSaveLabel}
+ placeholder="Variant Name"
+ borderWidth={1}
+ borderColor="transparent"
+ fontWeight="bold"
+ fontSize={16}
+ _hover={{ borderColor: "gray.300" }}
+ _focus={{ borderColor: "blue.500", outline: "none" }}
+ flex={1}
+ px={cellPadding.x}
+ onMouseEnter={() => setIsInputHovered(true)}
+ onMouseLeave={() => setIsInputHovered(false)}
+ />
+
+
+
+
);
}
diff --git a/src/components/OutputsTable/index.tsx b/src/components/OutputsTable/index.tsx
index 782fe4e..d35e327 100644
--- a/src/components/OutputsTable/index.tsx
+++ b/src/components/OutputsTable/index.tsx
@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
{variants.data.map((variant) => (
-
- 1} />
-
+ 1} />
))}
{
@@ -135,18 +137,28 @@ export const promptVariantsRouter = createTRPCRouter({
.input(
z.object({
experimentId: z.string(),
+ variantId: z.string().optional(),
}),
)
.mutation(async ({ input }) => {
- const lastVariant = await prisma.promptVariant.findFirst({
- where: {
- experimentId: input.experimentId,
- visible: true,
- },
- orderBy: {
- sortIndex: "desc",
- },
- });
+ let originalVariant: PromptVariant | null = null;
+ if (input.variantId) {
+ originalVariant = await prisma.promptVariant.findUnique({
+ where: {
+ id: input.variantId,
+ },
+ });
+ } else {
+ originalVariant = await prisma.promptVariant.findFirst({
+ where: {
+ experimentId: input.experimentId,
+ visible: true,
+ },
+ orderBy: {
+ sortIndex: "desc",
+ },
+ });
+ }
const largestSortIndex =
(
@@ -160,13 +172,18 @@ export const promptVariantsRouter = createTRPCRouter({
})
)._max?.sortIndex ?? 0;
+ const newVariantLabel =
+ input.variantId && originalVariant
+ ? `${originalVariant?.label} Copy`
+ : `Prompt Variant ${largestSortIndex + 2}`;
+
const createNewVariantAction = prisma.promptVariant.create({
data: {
experimentId: input.experimentId,
- label: `Prompt Variant ${largestSortIndex + 2}`,
- sortIndex: (lastVariant?.sortIndex ?? 0) + 1,
+ label: newVariantLabel,
+ sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn:
- lastVariant?.constructFn ??
+ originalVariant?.constructFn ??
dedent`
prompt = {
model: "gpt-3.5-turbo",
@@ -177,7 +194,7 @@ export const promptVariantsRouter = createTRPCRouter({
}
]
}`,
- model: lastVariant?.model ?? "gpt-3.5-turbo",
+ model: originalVariant?.model ?? "gpt-3.5-turbo",
},
});
@@ -186,6 +203,11 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId),
]);
+ if (originalVariant) {
+ // Insert new variant to right of original variant
+ await reorderPromptVariants(newVariant.id, originalVariant.id, true);
+ }
+
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
@@ -338,64 +360,6 @@ export const promptVariantsRouter = createTRPCRouter({
}),
)
.mutation(async ({ input }) => {
- const dragged = await prisma.promptVariant.findUnique({
- where: {
- id: input.draggedId,
- },
- });
-
- const dropped = await prisma.promptVariant.findUnique({
- where: {
- id: input.droppedId,
- },
- });
-
- if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
- throw new Error(
- `Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
- );
- }
-
- const visibleItems = await prisma.promptVariant.findMany({
- where: {
- experimentId: dragged.experimentId,
- visible: true,
- },
- orderBy: {
- sortIndex: "asc",
- },
- });
-
- // Remove the dragged item from its current position
- const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
-
- // Find the index of the dragged item and the dropped item
- const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
- const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
-
- // Determine the new index for the dragged item
- let newIndex;
- if (dragIndex < dropIndex) {
- newIndex = dropIndex + 1; // Insert after the dropped item
- } else {
- newIndex = dropIndex; // Insert before the dropped item
- }
-
- // Insert the dragged item at the new position
- orderedItems.splice(newIndex, 0, dragged);
-
- // Now, we need to update all the items with their new sortIndex
- await prisma.$transaction(
- orderedItems.map((item, index) => {
- return prisma.promptVariant.update({
- where: {
- id: item.id,
- },
- data: {
- sortIndex: index,
- },
- });
- }),
- );
+ await reorderPromptVariants(input.draggedId, input.droppedId);
}),
});
diff --git a/src/server/utils/reorderPromptVariants.ts b/src/server/utils/reorderPromptVariants.ts
new file mode 100644
index 0000000..9a1ba87
--- /dev/null
+++ b/src/server/utils/reorderPromptVariants.ts
@@ -0,0 +1,65 @@
+import { prisma } from "~/server/db";
+
+export const reorderPromptVariants = async (
+ movedId: string,
+ stationaryTargetId: string,
+ alwaysInsertRight?: boolean,
+) => {
+ const moved = await prisma.promptVariant.findUnique({
+ where: {
+ id: movedId,
+ },
+ });
+
+ const target = await prisma.promptVariant.findUnique({
+ where: {
+ id: stationaryTargetId,
+ },
+ });
+
+ if (!moved || !target || moved.experimentId !== target.experimentId) {
+ throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
+ }
+
+ const visibleItems = await prisma.promptVariant.findMany({
+ where: {
+ experimentId: moved.experimentId,
+ visible: true,
+ },
+ orderBy: {
+ sortIndex: "asc",
+ },
+ });
+
+ // Remove the moved item from its current position
+ const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
+
+ // Find the index of the moved item and the target item
+ const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
+ const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
+
+ // Determine the new index for the moved item
+ let newIndex;
+ if (movedIndex < targetIndex || alwaysInsertRight) {
+ newIndex = targetIndex + 1; // Insert after the target item
+ } else {
+ newIndex = targetIndex; // Insert before the target item
+ }
+
+ // Insert the moved item at the new position
+ orderedItems.splice(newIndex, 0, moved);
+
+ // Now, we need to update all the items with their new sortIndex
+ await prisma.$transaction(
+ orderedItems.map((item, index) => {
+ return prisma.promptVariant.update({
+ where: {
+ id: item.id,
+ },
+ data: {
+ sortIndex: index,
+ },
+ });
+ }),
+ );
+};