diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index 1338ab9..595dd20 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -42,8 +42,7 @@ export default function OutputCell({ { refetchInterval }, ); - const { mutateAsync: hardRefetchMutate } = - api.scenarioVariantCells.forceRefetch.useMutation(); + const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation(); const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => { await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id }); await utils.scenarioVariantCells.get.invalidate({ diff --git a/src/components/RefinePromptModal/CompareFunctions.tsx b/src/components/RefinePromptModal/CompareFunctions.tsx index 8b8561f..fee9736 100644 --- a/src/components/RefinePromptModal/CompareFunctions.tsx +++ b/src/components/RefinePromptModal/CompareFunctions.tsx @@ -35,7 +35,7 @@ const CompareFunctions = ({ return ( - + void; + loading: boolean; + onSubmit: () => void; +}) => { + return ( + + setInstructions(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) { + e.preventDefault(); + e.currentTarget.blur(); + onSubmit(); + } + }} + placeholder="Send custom instructions" + py={4} + pl={4} + pr={12} + colorScheme="orange" + borderColor="gray.300" + borderWidth={1} + _hover={{ + borderColor: "gray.300", + }} + _focus={{ + borderColor: "gray.300", + }} + isDisabled={loading} + /> + + + + + + ); +}; diff --git a/src/components/RefinePromptModal/RefineOption.tsx b/src/components/RefinePromptModal/RefineOption.tsx new file mode 100644 index 0000000..9a0eb77 --- /dev/null +++ b/src/components/RefinePromptModal/RefineOption.tsx @@ -0,0 +1,64 @@ +import { HStack, Icon, Heading, Text, VStack, GridItem } from "@chakra-ui/react"; +import { type IconType } from "react-icons"; +import { refineOptions, type RefineOptionLabel } from "./refineOptions"; + +export const RefineOption = ({ + label, + activeLabel, + icon, + onClick, + loading, +}: { + label: RefineOptionLabel; + activeLabel: RefineOptionLabel | undefined; + icon: IconType; + onClick: (label: RefineOptionLabel) => void; + loading: boolean; +}) => { + const isActive = activeLabel === label; + const desciption = refineOptions[label].description; + + return ( + + { + !loading && onClick(label); + }} + borderColor={isActive ? "blue.500" : "gray.200"} + borderWidth={2} + borderRadius={16} + padding={6} + backgroundColor="gray.50" + _hover={ + loading + ? undefined + : { + backgroundColor: "gray.100", + } + } + spacing={8} + boxShadow="0 0 40px 4px rgba(0, 0, 0, 0.1);" + cursor="pointer" + opacity={loading ? 0.5 : 1} + > + + + + {label} + + + + {desciption} + + + + ); +}; diff --git a/src/components/RefinePromptModal/RefinePromptModal.tsx b/src/components/RefinePromptModal/RefinePromptModal.tsx index 8bf206f..7cd13b5 100644 --- a/src/components/RefinePromptModal/RefinePromptModal.tsx +++ b/src/components/RefinePromptModal/RefinePromptModal.tsx @@ -11,17 +11,20 @@ import { Text, Spinner, HStack, - InputGroup, - InputRightElement, Icon, + SimpleGrid, } from "@chakra-ui/react"; -import { IoMdSend } from "react-icons/io"; +import { BsStars } from "react-icons/bs"; +import { VscJson } from "react-icons/vsc"; +import { TfiThought } from "react-icons/tfi"; import { api } from "~/utils/api"; import { useHandledAsyncCallback } from "~/utils/hooks"; import { type PromptVariant } from "@prisma/client"; import { useState } from "react"; import CompareFunctions from "./CompareFunctions"; -import AutoResizeTextArea from "../AutoResizeTextArea"; +import { CustomInstructionsInput } from "./CustomInstructionsInput"; +import { type RefineOptionLabel, refineOptions } from "./refineOptions"; +import { RefineOption } from "./RefineOption"; export const RefinePromptModal = ({ variant, @@ -36,13 +39,22 @@ export const RefinePromptModal = ({ api.promptVariants.getRefinedPromptFn.useMutation(); const [instructions, setInstructions] = useState(""); - const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(async () => { - if (!variant.experimentId) return; - await getRefinedPromptMutateAsync({ - id: variant.id, - instructions, - }); - }, [getRefinedPromptMutateAsync, onClose, variant, instructions]); + const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState< + RefineOptionLabel | undefined + >(undefined); + + const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback( + async (label?: RefineOptionLabel) => { + if (!variant.experimentId) return; + const updatedInstructions = label ? refineOptions[label].instructions : instructions; + setActiveRefineOptionLabel(label); + await getRefinedPromptMutateAsync({ + id: variant.id, + instructions: updatedInstructions, + }); + }, + [getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel], + ); const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); @@ -60,65 +72,42 @@ export const RefinePromptModal = ({ - Refine with GPT-4 + + + + Refine with GPT-4 + + - - - setInstructions(e.target.value)} - onKeyDown={(e) => { - if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) { - e.preventDefault(); - e.currentTarget.blur(); - getRefinedPromptFn(); - } - }} - placeholder="Send instructions" - py={4} - pl={4} - pr={12} - colorScheme="orange" - borderColor="gray.300" - borderWidth={1} - _hover={{ - borderColor: "gray.300", - }} - _focus={{ - borderColor: "gray.300", - }} - /> - - - - + loading={refiningInProgress} + /> + + + + or + + + - +