Compare commits
	
		
			7 Commits
		
	
	
		
			empty-scen
			...
			change-mod
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 01343efb6a | ||
|   | c7aaaea426 | ||
|   | 332e7afb0c | ||
|   | fe08e29f47 | ||
|   | 89ce730e52 | ||
|   | ad87c1b2eb | ||
|   | 58ddc72cbb | 
| @@ -1,4 +1,4 @@ | ||||
| import { HStack, VStack, useBreakpointValue } from "@chakra-ui/react"; | ||||
| import { type StackProps, VStack, useBreakpointValue } from "@chakra-ui/react"; | ||||
| import React from "react"; | ||||
| import DiffViewer, { DiffMethod } from "react-diff-viewer"; | ||||
| import Prism from "prismjs"; | ||||
| @@ -19,10 +19,15 @@ const highlightSyntax = (str: string) => { | ||||
| const CompareFunctions = ({ | ||||
|   originalFunction, | ||||
|   newFunction = "", | ||||
|   leftTitle = "Original", | ||||
|   rightTitle = "Modified", | ||||
|   ...props | ||||
| }: { | ||||
|   originalFunction: string; | ||||
|   newFunction?: string; | ||||
| }) => { | ||||
|   leftTitle?: string; | ||||
|   rightTitle?: string; | ||||
| } & StackProps) => { | ||||
|   const showSplitView = useBreakpointValue( | ||||
|     { | ||||
|       base: false, | ||||
| @@ -34,22 +39,20 @@ const CompareFunctions = ({ | ||||
|   ); | ||||
|  | ||||
|   return ( | ||||
|     <HStack w="full" spacing={5}> | ||||
|       <VStack w="full" spacing={4} maxH="40vh" fontSize={12} lineHeight={1} overflowY="auto"> | ||||
|         <DiffViewer | ||||
|           oldValue={originalFunction} | ||||
|           newValue={newFunction || originalFunction} | ||||
|           splitView={showSplitView} | ||||
|           hideLineNumbers={!showSplitView} | ||||
|           leftTitle="Original" | ||||
|           rightTitle={newFunction ? "Modified" : "Unmodified"} | ||||
|           disableWordDiff={true} | ||||
|           compareMethod={DiffMethod.CHARS} | ||||
|           renderContent={highlightSyntax} | ||||
|           showDiffOnly={false} | ||||
|         /> | ||||
|       </VStack> | ||||
|     </HStack> | ||||
|     <VStack w="full" spacing={4} fontSize={12} lineHeight={1} overflowY="auto" {...props}> | ||||
|       <DiffViewer | ||||
|         oldValue={originalFunction} | ||||
|         newValue={newFunction || originalFunction} | ||||
|         splitView={showSplitView} | ||||
|         hideLineNumbers={!showSplitView} | ||||
|         leftTitle={leftTitle} | ||||
|         rightTitle={rightTitle} | ||||
|         disableWordDiff={true} | ||||
|         compareMethod={DiffMethod.CHARS} | ||||
|         renderContent={highlightSyntax} | ||||
|         showDiffOnly={false} | ||||
|       /> | ||||
|     </VStack> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
|   | ||||
| @@ -56,7 +56,6 @@ export const CustomInstructionsInput = ({ | ||||
|           minW="unset" | ||||
|           size="sm" | ||||
|           onClick={() => onSubmit()} | ||||
|           disabled={!instructions} | ||||
|           variant={instructions ? "solid" : "ghost"} | ||||
|           mr={4} | ||||
|           borderRadius="8" | ||||
|   | ||||
| @@ -36,25 +36,25 @@ export const RefinePromptModal = ({ | ||||
| }) => { | ||||
|   const utils = api.useContext(); | ||||
|  | ||||
|   const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } = | ||||
|     api.promptVariants.getRefinedPromptFn.useMutation(); | ||||
|   const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } = | ||||
|     api.promptVariants.getModifiedPromptFn.useMutation(); | ||||
|   const [instructions, setInstructions] = useState<string>(""); | ||||
|  | ||||
|   const [activeRefineOptionLabel, setActiveRefineOptionLabel] = useState< | ||||
|     RefineOptionLabel | undefined | ||||
|   >(undefined); | ||||
|  | ||||
|   const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback( | ||||
|   const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback( | ||||
|     async (label?: RefineOptionLabel) => { | ||||
|       if (!variant.experimentId) return; | ||||
|       const updatedInstructions = label ? refineOptions[label].instructions : instructions; | ||||
|       setActiveRefineOptionLabel(label); | ||||
|       await getRefinedPromptMutateAsync({ | ||||
|       await getModifiedPromptMutateAsync({ | ||||
|         id: variant.id, | ||||
|         instructions: updatedInstructions, | ||||
|       }); | ||||
|     }, | ||||
|     [getRefinedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel], | ||||
|     [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineOptionLabel], | ||||
|   ); | ||||
|  | ||||
|   const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); | ||||
| @@ -75,7 +75,11 @@ export const RefinePromptModal = ({ | ||||
|   }, [replaceVariantMutation, variant, onClose, refinedPromptFn]); | ||||
|  | ||||
|   return ( | ||||
|     <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}> | ||||
|     <Modal | ||||
|       isOpen | ||||
|       onClose={onClose} | ||||
|       size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }} | ||||
|     > | ||||
|       <ModalOverlay /> | ||||
|       <ModalContent w={1200}> | ||||
|         <ModalHeader> | ||||
| @@ -93,15 +97,15 @@ export const RefinePromptModal = ({ | ||||
|                   label="Convert to function call" | ||||
|                   activeLabel={activeRefineOptionLabel} | ||||
|                   icon={VscJson} | ||||
|                   onClick={getRefinedPromptFn} | ||||
|                   loading={refiningInProgress} | ||||
|                   onClick={getModifiedPromptFn} | ||||
|                   loading={modificationInProgress} | ||||
|                 /> | ||||
|                 <RefineOption | ||||
|                   label="Add chain of thought" | ||||
|                   activeLabel={activeRefineOptionLabel} | ||||
|                   icon={TfiThought} | ||||
|                   onClick={getRefinedPromptFn} | ||||
|                   loading={refiningInProgress} | ||||
|                   onClick={getModifiedPromptFn} | ||||
|                   loading={modificationInProgress} | ||||
|                 /> | ||||
|               </SimpleGrid> | ||||
|               <HStack> | ||||
| @@ -110,13 +114,14 @@ export const RefinePromptModal = ({ | ||||
|               <CustomInstructionsInput | ||||
|                 instructions={instructions} | ||||
|                 setInstructions={setInstructions} | ||||
|                 loading={refiningInProgress} | ||||
|                 onSubmit={getRefinedPromptFn} | ||||
|                 loading={modificationInProgress} | ||||
|                 onSubmit={getModifiedPromptFn} | ||||
|               /> | ||||
|             </VStack> | ||||
|             <CompareFunctions | ||||
|               originalFunction={variant.constructFn} | ||||
|               newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined} | ||||
|               maxH="40vh" | ||||
|             /> | ||||
|           </VStack> | ||||
|         </ModalBody> | ||||
| @@ -124,12 +129,10 @@ export const RefinePromptModal = ({ | ||||
|         <ModalFooter> | ||||
|           <HStack spacing={4}> | ||||
|             <Button | ||||
|               colorScheme="blue" | ||||
|               onClick={replaceVariant} | ||||
|               minW={24} | ||||
|               disabled={replacementInProgress || !refinedPromptFn} | ||||
|               _disabled={{ | ||||
|                 bgColor: "blue.500", | ||||
|               }} | ||||
|               isDisabled={replacementInProgress || !refinedPromptFn} | ||||
|             > | ||||
|               {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>} | ||||
|             </Button> | ||||
|   | ||||
| @@ -12,7 +12,7 @@ export const refineOptions: Record< | ||||
|        | ||||
|     This is what a prompt looks like before adding chain of thought: | ||||
|      | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-4", | ||||
|         stream: true, | ||||
|         messages: [ | ||||
| @@ -25,11 +25,11 @@ export const refineOptions: Record< | ||||
|             content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`, | ||||
|             }, | ||||
|         ], | ||||
|     }; | ||||
|     }); | ||||
|  | ||||
|     This is what one looks like after adding chain of thought: | ||||
|  | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-4", | ||||
|         stream: true, | ||||
|         messages: [ | ||||
| @@ -42,13 +42,13 @@ export const refineOptions: Record< | ||||
|             content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral". Explain your answer before you give a score, then return the score on a new line.\`, | ||||
|             }, | ||||
|         ], | ||||
|     }; | ||||
|     }); | ||||
|  | ||||
|     Here's another example: | ||||
|  | ||||
|     Before: | ||||
|  | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-3.5-turbo", | ||||
|         messages: [ | ||||
|           { | ||||
| @@ -78,11 +78,11 @@ export const refineOptions: Record< | ||||
|         function_call: { | ||||
|           name: "score_post", | ||||
|         }, | ||||
|       }; | ||||
|       }); | ||||
|  | ||||
|     After: | ||||
|  | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-3.5-turbo", | ||||
|         messages: [ | ||||
|           { | ||||
| @@ -115,7 +115,7 @@ export const refineOptions: Record< | ||||
|         function_call: { | ||||
|           name: "score_post", | ||||
|         }, | ||||
|       }; | ||||
|       }); | ||||
|  | ||||
|     Add chain of thought to the original prompt.`, | ||||
|   }, | ||||
| @@ -125,7 +125,7 @@ export const refineOptions: Record< | ||||
|      | ||||
|     This is what a prompt looks like before adding a function: | ||||
|      | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|       model: "gpt-4", | ||||
|       stream: true, | ||||
|       messages: [ | ||||
| @@ -138,11 +138,11 @@ export const refineOptions: Record< | ||||
|           content: \`This is the user's message: \${scenario.user_message}. Return "positive" or "negative" or "neutral"\`, | ||||
|         }, | ||||
|       ], | ||||
|     }; | ||||
|     }); | ||||
|    | ||||
|     This is what one looks like after adding a function: | ||||
|    | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|       model: "gpt-4", | ||||
|       stream: true, | ||||
|       messages: [ | ||||
| @@ -172,13 +172,13 @@ export const refineOptions: Record< | ||||
|       function_call: { | ||||
|         name: "extract_sentiment", | ||||
|       }, | ||||
|     }; | ||||
|     }); | ||||
|  | ||||
|     Here's another example of adding a function: | ||||
|  | ||||
|     Before: | ||||
|  | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-3.5-turbo", | ||||
|         messages: [ | ||||
|           { | ||||
| @@ -196,11 +196,11 @@ export const refineOptions: Record< | ||||
|           }, | ||||
|         ], | ||||
|         temperature: 0, | ||||
|     }; | ||||
|     }); | ||||
|  | ||||
|     After: | ||||
|  | ||||
|     prompt = { | ||||
|     definePrompt("openai/ChatCompletion", { | ||||
|         model: "gpt-3.5-turbo", | ||||
|         messages: [ | ||||
|           { | ||||
| @@ -230,7 +230,7 @@ export const refineOptions: Record< | ||||
|         function_call: { | ||||
|           name: "score_post", | ||||
|         }, | ||||
|       }; | ||||
|       }); | ||||
|      | ||||
|     Add an OpenAI function that takes one or more nested parameters that match the expected output from this prompt.`, | ||||
|   }, | ||||
|   | ||||
| @@ -20,36 +20,60 @@ import { ModelStatsCard } from "./ModelStatsCard"; | ||||
| import { SelectModelSearch } from "./SelectModelSearch"; | ||||
| import { api } from "~/utils/api"; | ||||
| import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; | ||||
| import CompareFunctions from "../RefinePromptModal/CompareFunctions"; | ||||
| import { type PromptVariant } from "@prisma/client"; | ||||
| import { isObject, isString } from "lodash-es"; | ||||
|  | ||||
| export const SelectModelModal = ({ | ||||
|   originalModel, | ||||
|   variantId, | ||||
|   variant, | ||||
|   onClose, | ||||
| }: { | ||||
|   originalModel: SupportedModel; | ||||
|   variantId: string; | ||||
|   variant: PromptVariant; | ||||
|   onClose: () => void; | ||||
| }) => { | ||||
|   const originalModel = variant.model as SupportedModel; | ||||
|   const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel); | ||||
|   const [convertedModel, setConvertedModel] = useState<SupportedModel | undefined>(undefined); | ||||
|   const utils = api.useContext(); | ||||
|  | ||||
|   const experiment = useExperiment(); | ||||
|  | ||||
|   const createMutation = api.promptVariants.create.useMutation(); | ||||
|   const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } = | ||||
|     api.promptVariants.getModifiedPromptFn.useMutation(); | ||||
|  | ||||
|   const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => { | ||||
|     if (!experiment?.data?.id) return; | ||||
|     await createMutation.mutateAsync({ | ||||
|       experimentId: experiment?.data?.id, | ||||
|       variantId, | ||||
|   const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => { | ||||
|     if (!experiment) return; | ||||
|  | ||||
|     await getModifiedPromptMutateAsync({ | ||||
|       id: variant.id, | ||||
|       newModel: selectedModel, | ||||
|     }); | ||||
|     setConvertedModel(selectedModel); | ||||
|   }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]); | ||||
|  | ||||
|   const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); | ||||
|  | ||||
|   const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => { | ||||
|     if ( | ||||
|       !variant.experimentId || | ||||
|       !modifiedPromptFn || | ||||
|       (isObject(modifiedPromptFn) && "status" in modifiedPromptFn) | ||||
|     ) | ||||
|       return; | ||||
|     await replaceVariantMutation.mutateAsync({ | ||||
|       id: variant.id, | ||||
|       constructFn: modifiedPromptFn, | ||||
|     }); | ||||
|     await utils.promptVariants.list.invalidate(); | ||||
|     onClose(); | ||||
|   }, [createMutation, experiment?.data?.id, variantId, onClose]); | ||||
|   }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); | ||||
|  | ||||
|   return ( | ||||
|     <Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}> | ||||
|     <Modal | ||||
|       isOpen | ||||
|       onClose={onClose} | ||||
|       size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }} | ||||
|     > | ||||
|       <ModalOverlay /> | ||||
|       <ModalContent w={1200}> | ||||
|         <ModalHeader> | ||||
| @@ -66,18 +90,36 @@ export const SelectModelModal = ({ | ||||
|               <ModelStatsCard label="New Model" model={selectedModel} /> | ||||
|             )} | ||||
|             <SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} /> | ||||
|             {isString(modifiedPromptFn) && ( | ||||
|               <CompareFunctions | ||||
|                 originalFunction={variant.constructFn} | ||||
|                 newFunction={modifiedPromptFn} | ||||
|                 leftTitle={originalModel} | ||||
|                 rightTitle={convertedModel} | ||||
|               /> | ||||
|             )} | ||||
|           </VStack> | ||||
|         </ModalBody> | ||||
|  | ||||
|         <ModalFooter> | ||||
|           <Button | ||||
|             colorScheme="blue" | ||||
|             onClick={createNewVariant} | ||||
|             minW={24} | ||||
|             disabled={originalModel === selectedModel} | ||||
|           > | ||||
|             {creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>} | ||||
|           </Button> | ||||
|           <HStack> | ||||
|             <Button | ||||
|               colorScheme="gray" | ||||
|               onClick={getModifiedPromptFn} | ||||
|               minW={24} | ||||
|               isDisabled={originalModel === selectedModel || modificationInProgress} | ||||
|             > | ||||
|               {modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>} | ||||
|             </Button> | ||||
|             <Button | ||||
|               colorScheme="blue" | ||||
|               onClick={replaceVariant} | ||||
|               minW={24} | ||||
|               isDisabled={!convertedModel || modificationInProgress || replacementInProgress} | ||||
|             > | ||||
|               {replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>} | ||||
|             </Button> | ||||
|           </HStack> | ||||
|         </ModalFooter> | ||||
|       </ModalContent> | ||||
|     </Modal> | ||||
|   | ||||
| @@ -18,7 +18,6 @@ import { useState } from "react"; | ||||
| import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal"; | ||||
| import { RiExchangeFundsFill } from "react-icons/ri"; | ||||
| import { SelectModelModal } from "../SelectModelModal/SelectModelModal"; | ||||
| import { type SupportedModel } from "~/server/types"; | ||||
|  | ||||
| export default function VariantHeaderMenuButton({ | ||||
|   variant, | ||||
| @@ -99,11 +98,7 @@ export default function VariantHeaderMenuButton({ | ||||
|         </MenuList> | ||||
|       </Menu> | ||||
|       {selectModelModalOpen && ( | ||||
|         <SelectModelModal | ||||
|           originalModel={variant.model as SupportedModel} | ||||
|           variantId={variant.id} | ||||
|           onClose={() => setSelectModelModalOpen(false)} | ||||
|         /> | ||||
|         <SelectModelModal variant={variant} onClose={() => setSelectModelModalOpen(false)} /> | ||||
|       )} | ||||
|       {refinePromptModalOpen && ( | ||||
|         <RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} /> | ||||
|   | ||||
| @@ -284,11 +284,12 @@ export const promptVariantsRouter = createTRPCRouter({ | ||||
|       return updatedPromptVariant; | ||||
|     }), | ||||
|  | ||||
|   getRefinedPromptFn: protectedProcedure | ||||
|   getModifiedPromptFn: protectedProcedure | ||||
|     .input( | ||||
|       z.object({ | ||||
|         id: z.string(), | ||||
|         instructions: z.string(), | ||||
|         instructions: z.string().optional(), | ||||
|         newModel: z.string().optional(), | ||||
|       }), | ||||
|     ) | ||||
|     .mutation(async ({ input, ctx }) => { | ||||
| @@ -307,7 +308,7 @@ export const promptVariantsRouter = createTRPCRouter({ | ||||
|  | ||||
|       const promptConstructionFn = await deriveNewConstructFn( | ||||
|         existing, | ||||
|         constructedPrompt.model as SupportedModel, | ||||
|         input.newModel as SupportedModel | undefined, | ||||
|         input.instructions, | ||||
|       ); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user