Refine prompt (#63)

* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier

* Remove migration script

* Add refine modal

* Fix prettier

* Fix diff checker overflow

* Decrease diff height
This commit is contained in:
arcticfly
2023-07-19 15:31:40 -07:00
committed by GitHub
parent 58892d8b63
commit 4c97b9f147
10 changed files with 550 additions and 155 deletions

View File

@@ -1,4 +1,4 @@
import { Box, Button, Icon, Spinner } from "@chakra-ui/react";
import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
@@ -34,7 +34,7 @@ export default function NewVariantButton() {
minH={headerMinHeight}
>
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
Add Variant
<Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button>
);
}

View File

@@ -22,7 +22,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
}, [lastSavedFn]);
useEffect(checkForChanges, [checkForChanges, lastSavedFn]);
const matchUpdatedSavedFn = useCallback(() => {
if (!editorRef.current) return;
editorRef.current.setValue(lastSavedFn);
setIsChanged(false);
}, [lastSavedFn]);
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext();

View File

@@ -0,0 +1,45 @@
import { HStack, VStack } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
import "prismjs/components/prism-javascript";
import "prismjs/themes/prism.css"; // choose a theme you like
const CompareFunctions = ({
originalFunction,
newFunction = "",
}: {
originalFunction: string;
newFunction?: string;
}) => {
console.log("newFunction", newFunction);
const highlightSyntax = (str: string) => {
let highlighted;
try {
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
} catch (e) {
console.error("Error highlighting:", e);
highlighted = str;
}
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
};
return (
<HStack w="full" spacing={5}>
<VStack w="full" spacing={4} maxH="65vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={true}
hideLineNumbers={true}
leftTitle="Original"
rightTitle={newFunction ? "Modified" : "Unmodified"}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
/>
</VStack>
</HStack>
);
};
export default CompareFunctions;

View File

@@ -0,0 +1,103 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
} from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import AutoResizeTextArea from "../AutoResizeTextArea";
import CompareFunctions from "./CompareFunctions";
export const RefinePromptModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId) return;
await getRefinedPromptMutateAsync({
id: variant.id,
instructions,
});
}, [getRefinedPromptMutateAsync, onClose, variant, instructions]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId || !refinedPromptFn) return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: refinedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Refine Your Prompt</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<HStack w="full">
<AutoResizeTextArea
value={instructions}
onChange={(e) => setInstructions(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
e.preventDefault();
e.currentTarget.blur();
getRefinedPromptFn();
}
}}
placeholder="Use chain of thought"
/>
<Button onClick={getRefinedPromptFn}>
{refiningInProgress ? <Spinner boxSize={4} /> : <Text>Submit</Text>}
</Button>
</HStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={refinedPromptFn}
/>
</VStack>
</ModalBody>
<ModalFooter>
<HStack spacing={4}>
<Button onClick={onClose}>Cancel</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
disabled={!refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -32,18 +32,18 @@ export const SelectModelModal = ({
const experiment = useExperiment();
const duplicateMutation = api.promptVariants.create.useMutation();
const createMutation = api.promptVariants.create.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await duplicateMutation.mutateAsync({
await createMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
newModel: selectedModel,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [duplicateMutation, experiment?.data?.id, variantId, onClose]);
}, [createMutation, experiment?.data?.id, variantId, onClose]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>

View File

@@ -11,13 +11,15 @@ import {
MenuDivider,
Text,
Spinner,
} from "@chakra-ui/react"; // Changed here
} from "@chakra-ui/react";
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { RiExchangeFundsFill } from "react-icons/ri";
import { AiOutlineDiff } from "react-icons/ai";
import { useState } from "react";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
export default function VariantHeaderMenuButton({
variant,
@@ -51,6 +53,7 @@ export default function VariantHeaderMenuButton({
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
<>
@@ -75,6 +78,12 @@ export default function VariantHeaderMenuButton({
>
Change Model
</MenuItem>
<MenuItem
icon={<Icon as={AiOutlineDiff} boxSize={5} />}
onClick={() => setRefinePromptModalOpen(true)}
>
Refine
</MenuItem>
{canHide && (
<>
<MenuDivider />
@@ -97,6 +106,9 @@ export default function VariantHeaderMenuButton({
onClose={() => setSelectModelModalOpen(false)}
/>
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
)}
</>
);
}