import {
  Button,
  Modal,
  ModalBody,
  ModalCloseButton,
  ModalContent,
  ModalFooter,
  ModalHeader,
  ModalOverlay,
  VStack,
  Text,
  Spinner,
  HStack,
  Icon,
  SimpleGrid,
} from "@chakra-ui/react";
import { BsStars } from "react-icons/bs";
import { api } from "~/utils/api";
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "../CustomInstructionsInput";
import { RefineAction } from "./RefineAction";
import { isObject, isString } from "lodash-es";
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";

/**
 * Modal for refining a prompt variant: it requests a modified prompt function
 * via `promptVariants.getModifiedPromptFn` and can then replace the variant
 * with the result via `promptVariants.replaceVariant`.
 *
 * The instructions sent for refinement come either from one of the provider's
 * predefined refinement actions (selected by label) or from the free-form
 * `instructions` state.
 *
 * @param variant - The prompt variant being refined. Its `modelProvider`
 *   selects which refinement actions are offered.
 * @param onClose - Invoked after a successful replacement. It may also be wired
 *   to the modal's close controls in the markup (not visible here — confirm).
 */
export const RefinePromptModal = ({
  variant,
  onClose,
}: {
  variant: PromptVariant;
  onClose: () => void;
}) => {
  const utils = api.useContext();
  const visibleScenarios = useVisibleScenarioIds();

  // Predefined refinement actions for this variant's model provider, keyed by
  // label; falls back to an empty object when the provider defines none.
  const refinementActions =
    frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
  // Boolean guard for rendering: `length && …` would make React render a
  // literal `0` when there are no actions.
  const hasRefinementActions = Object.keys(refinementActions).length > 0;

  const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } =
    api.promptVariants.getModifiedPromptFn.useMutation();
  const [instructions, setInstructions] = useState("");

  // Explicit type argument: `useState(undefined)` alone infers `undefined`,
  // which makes the setter reject the string labels passed to it below.
  const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
    undefined,
  );

  const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
    async (label?: string) => {
      if (!variant.experimentId) return;
      // A label selects that refinement action's predefined instructions;
      // without one, the free-form `instructions` state is sent instead.
      const updatedInstructions = label
        ? (refinementActions[label] as RefinementAction).instructions
        : instructions;
      setActiveRefineActionLabel(label);
      await getModifiedPromptMutateAsync({
        id: variant.id,
        instructions: updatedInstructions,
      });
    },
    [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
  );

  const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();

  // `visibleScenarios` and `utils` are read inside this callback, so they are
  // listed as dependencies; otherwise the memoized callback could stream a
  // stale set of scenario ids.
  const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
    // Only apply a refined prompt constructor once one has come back. An
    // object carrying a `status` field is not applied (presumably an error
    // payload from the router — confirm).
    if (
      !variant.experimentId ||
      !refinedPromptFn ||
      (isObject(refinedPromptFn) && "status" in refinedPromptFn)
    )
      return;
    await replaceVariantMutation.mutateAsync({
      id: variant.id,
      promptConstructor: refinedPromptFn,
      streamScenarios: visibleScenarios,
    });
    await utils.promptVariants.list.invalidate();
    onClose();
  }, [replaceVariantMutation, variant, onClose, refinedPromptFn, visibleScenarios, utils]);

  // NOTE(review): the JSX element tags of this return statement are missing from
  // the reviewed source; only text and `{…}` expressions survive. They are left
  // untouched here, except that `hasRefinementActions` replaces the numeric
  // `length` guard. Restore the markup.
  return ( Refine with GPT-4 {hasRefinementActions && ( <> {Object.keys(refinementActions).map((label) => ( ))} or )} getModifiedPromptFn()} /> );
};