Set contents of editor for refinement modals instead of saving on server (#199)

* Set contents of editor for refinement modals instead of saving on server

* Show New Experiment text on mobile
This commit is contained in:
arcticfly
2023-08-27 19:52:25 -06:00
committed by GitHub
parent 28713fb3ef
commit fa87887e91
7 changed files with 105 additions and 85 deletions

View File

@@ -1,3 +1,4 @@
import { useState, useMemo, useCallback } from "react";
import { import {
Button, Button,
HStack, HStack,
@@ -14,16 +15,18 @@ import {
VStack, VStack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type PromptVariant } from "@prisma/client"; import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es"; import { isString } from "lodash-es";
import { useState } from "react";
import { RiExchangeFundsFill } from "react-icons/ri"; import { RiExchangeFundsFill } from "react-icons/ri";
import { type ProviderModel } from "~/modelProviders/types"; import { type ProviderModel } from "~/modelProviders/types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { lookupModel, modelLabel } from "~/utils/utils"; import { lookupModel, modelLabel } from "~/utils/utils";
import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { ModelSearch } from "./ModelSearch"; import { ModelSearch } from "./ModelSearch";
import { ModelStatsCard } from "./ModelStatsCard"; import { ModelStatsCard } from "./ModelStatsCard";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import { useAppStore } from "~/state/store";
export const ChangeModelModal = ({ export const ChangeModelModal = ({
variant, variant,
@@ -32,48 +35,43 @@ export const ChangeModelModal = ({
variant: PromptVariant; variant: PromptVariant;
onClose: () => void; onClose: () => void;
}) => { }) => {
const editorOptionsMap = useAppStore((s) => s.sharedVariantEditor.editorOptionsMap);
const originalPromptFn = useMemo(
() => editorOptionsMap[variant.uiId]?.getContent() || "",
[editorOptionsMap, variant.uiId],
);
const originalModel = lookupModel(variant.modelProvider, variant.model); const originalModel = lookupModel(variant.modelProvider, variant.model);
const [selectedModel, setSelectedModel] = useState({ const [selectedModel, setSelectedModel] = useState({
provider: variant.modelProvider, provider: variant.modelProvider,
model: variant.model, model: variant.model,
} as ProviderModel); } as ProviderModel);
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>(); const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
const visibleScenarios = useVisibleScenarioIds(); const [modifiedPromptFn, setModifiedPromptFn] = useState<string>();
const utils = api.useContext();
const experiment = useExperiment(); const experiment = useExperiment();
const { mutateAsync: getModifiedPromptMutateAsync, data: modifiedPromptFn } = const { mutateAsync: getModifiedPromptMutateAsync } =
api.promptVariants.getModifiedPromptFn.useMutation(); api.promptVariants.getModifiedPromptFn.useMutation();
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => { const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment) return; if (!experiment) return;
await getModifiedPromptMutateAsync({ const resp = await getModifiedPromptMutateAsync({
id: variant.id, id: variant.id,
originalPromptFn,
newModel: selectedModel, newModel: selectedModel,
}); });
if (maybeReportError(resp)) return;
setModifiedPromptFn(resp.payload);
setConvertedModel(selectedModel); setConvertedModel(selectedModel);
}, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]); }, [getModifiedPromptMutateAsync, onClose, experiment, variant, selectedModel]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); const replaceVariant = useCallback(() => {
if (!modifiedPromptFn) return;
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => { editorOptionsMap[variant.uiId]?.setContent(modifiedPromptFn);
if (
!variant.experimentId ||
!modifiedPromptFn ||
(isObject(modifiedPromptFn) && "status" in modifiedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
promptConstructor: modifiedPromptFn,
streamScenarios: visibleScenarios,
});
await utils.promptVariants.list.invalidate();
onClose(); onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); }, [variant.uiId, editorOptionsMap, onClose, modifiedPromptFn]);
const originalLabel = modelLabel(variant.modelProvider, variant.model); const originalLabel = modelLabel(variant.modelProvider, variant.model);
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model); const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
@@ -130,9 +128,9 @@ export const ChangeModelModal = ({
colorScheme="blue" colorScheme="blue"
onClick={replaceVariant} onClick={replaceVariant}
minW={24} minW={24}
isDisabled={!convertedModel || modificationInProgress || replacementInProgress} isDisabled={!convertedModel || modificationInProgress}
> >
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>} Accept
</Button> </Button>
</HStack> </HStack>
</ModalFooter> </ModalFooter>

View File

@@ -10,7 +10,7 @@ import {
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { FiMaximize, FiMinimize } from "react-icons/fi"; import { FiMaximize, FiMinimize } from "react-icons/fi";
import { editorBackground } from "~/state/sharedVariantEditor.slice"; import { type CreatedEditor, editorBackground } from "~/state/sharedVariantEditor.slice";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { import {
@@ -24,8 +24,10 @@ import { type PromptVariant } from "./types";
export default function VariantEditor(props: { variant: PromptVariant }) { export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess(); const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco(); const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null); const updateOptionsForEditor = useAppStore.use.sharedVariantEditor.updateOptionsForEditor();
const editorRef = useRef<CreatedEditor | null>(null);
const containerRef = useRef<HTMLDivElement | null>(null); const containerRef = useRef<HTMLDivElement | null>(null);
const lastSavedFnRef = useRef(props.variant.promptConstructor);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
const [isChanged, setIsChanged] = useState(false); const [isChanged, setIsChanged] = useState(false);
@@ -48,22 +50,18 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
}, [isFullscreen, toggleFullscreen]); }, [isFullscreen, toggleFullscreen]);
const lastSavedFn = props.variant.promptConstructor; const lastSavedFn = props.variant.promptConstructor;
useEffect(() => {
// Store in ref so that we can access it dynamically
lastSavedFnRef.current = lastSavedFn;
}, [lastSavedFn]);
const modifierKey = useModifierKeyLabel(); const modifierKey = useModifierKeyLabel();
const checkForChanges = useCallback(() => { const checkForChanges = useCallback(() => {
if (!editorRef.current) return; if (!editorRef.current) return;
const currentFn = editorRef.current.getValue(); const currentFn = editorRef.current.getValue();
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn); setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFnRef.current);
}, [lastSavedFn]); }, [editorRef]);
const matchUpdatedSavedFn = useCallback(() => {
if (!editorRef.current) return;
editorRef.current.setValue(lastSavedFn);
setIsChanged(false);
}, [lastSavedFn]);
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
const replaceVariant = api.promptVariants.replaceVariant.useMutation(); const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext(); const utils = api.useContext();
@@ -136,6 +134,11 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
readOnly: !canModify, readOnly: !canModify,
}); });
updateOptionsForEditor(props.variant.uiId, {
getContent: () => editorRef.current?.getValue() || "",
setContent: (content) => editorRef.current?.setValue(content),
});
// Workaround because otherwise the commands only work on whatever // Workaround because otherwise the commands only work on whatever
// editor was loaded on the page last. // editor was loaded on the page last.
// https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201 // https://github.com/microsoft/monaco-editor/issues/2947#issuecomment-1422265201
@@ -155,7 +158,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
}); });
}); });
editorRef.current.onDidChangeModelContent(checkForChanges); const checkForChangesListener = editorRef.current.onDidChangeModelContent(checkForChanges);
const resizeObserver = new ResizeObserver(() => { const resizeObserver = new ResizeObserver(() => {
editorRef.current?.layout(); editorRef.current?.layout();
@@ -164,6 +167,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
return () => { return () => {
resizeObserver.disconnect(); resizeObserver.disconnect();
checkForChangesListener.dispose();
editorRef.current?.dispose(); editorRef.current?.dispose();
}; };
} }
@@ -171,7 +175,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
// We intentionally skip the onSave and props.savedConfig dependencies here because // We intentionally skip the onSave and props.savedConfig dependencies here because
// we don't want to re-render the editor from scratch // we don't want to re-render the editor from scratch
/* eslint-disable-next-line react-hooks/exhaustive-deps */ /* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]); }, [monaco, editorId, updateOptionsForEditor]);
useEffect(() => { useEffect(() => {
if (!editorRef.current) return; if (!editorRef.current) return;

View File

@@ -1,3 +1,4 @@
import { useState, useMemo, useCallback } from "react";
import { import {
Button, Button,
Modal, Modal,
@@ -9,22 +10,23 @@ import {
ModalOverlay, ModalOverlay,
VStack, VStack,
Text, Text,
Spinner,
HStack, HStack,
Icon, Icon,
SimpleGrid, SimpleGrid,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { BsStars } from "react-icons/bs"; import { BsStars } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client"; import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import CompareFunctions from "./CompareFunctions"; import CompareFunctions from "./CompareFunctions";
import { CustomInstructionsInput } from "../CustomInstructionsInput"; import { CustomInstructionsInput } from "../CustomInstructionsInput";
import { RefineAction } from "./RefineAction"; import { RefineAction } from "./RefineAction";
import { isObject, isString } from "lodash-es"; import { isString } from "lodash-es";
import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types"; import { type RefinementAction, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { useAppStore } from "~/state/store";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
export const RefinePromptModal = ({ export const RefinePromptModal = ({
variant, variant,
@@ -33,19 +35,23 @@ export const RefinePromptModal = ({
variant: PromptVariant; variant: PromptVariant;
onClose: () => void; onClose: () => void;
}) => { }) => {
const utils = api.useContext(); const editorOptionsMap = useAppStore((s) => s.sharedVariantEditor.editorOptionsMap);
const visibleScenarios = useVisibleScenarioIds(); const originalPromptFn = useMemo(
() => editorOptionsMap[variant.uiId]?.getContent() || "",
[editorOptionsMap, variant.uiId],
);
const refinementActions = const refinementActions =
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {}; frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
const { mutateAsync: getModifiedPromptMutateAsync, data: refinedPromptFn } = const { mutateAsync: getModifiedPromptMutateAsync } =
api.promptVariants.getModifiedPromptFn.useMutation(); api.promptVariants.getModifiedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>(""); const [instructions, setInstructions] = useState<string>("");
const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>( const [activeRefineActionLabel, setActiveRefineActionLabel] = useState<string | undefined>(
undefined, undefined,
); );
const [refinedPromptFn, setRefinedPromptFn] = useState<string>();
const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback( const [getModifiedPromptFn, modificationInProgress] = useHandledAsyncCallback(
async (label?: string) => { async (label?: string) => {
@@ -54,31 +60,22 @@ export const RefinePromptModal = ({
? (refinementActions[label] as RefinementAction).instructions ? (refinementActions[label] as RefinementAction).instructions
: instructions; : instructions;
setActiveRefineActionLabel(label); setActiveRefineActionLabel(label);
await getModifiedPromptMutateAsync({ const resp = await getModifiedPromptMutateAsync({
id: variant.id, id: variant.id,
originalPromptFn,
instructions: updatedInstructions, instructions: updatedInstructions,
}); });
if (maybeReportError(resp)) return;
setRefinedPromptFn(resp.payload);
}, },
[getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel], [getModifiedPromptMutateAsync, onClose, variant, instructions, setActiveRefineActionLabel],
); );
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation(); const replaceVariant = useCallback(() => {
if (!refinedPromptFn) return;
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => { editorOptionsMap[variant.uiId]?.setContent(refinedPromptFn);
if (
!variant.experimentId ||
!refinedPromptFn ||
(isObject(refinedPromptFn) && "status" in refinedPromptFn)
)
return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
promptConstructor: refinedPromptFn,
streamScenarios: visibleScenarios,
});
await utils.promptVariants.list.invalidate();
onClose(); onClose();
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]); }, [variant.uiId, editorOptionsMap, onClose, refinedPromptFn]);
return ( return (
<Modal <Modal
@@ -126,7 +123,7 @@ export const RefinePromptModal = ({
/> />
</VStack> </VStack>
<CompareFunctions <CompareFunctions
originalFunction={variant.promptConstructor} originalFunction={originalPromptFn}
newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined} newFunction={isString(refinedPromptFn) ? refinedPromptFn : undefined}
maxH="40vh" maxH="40vh"
/> />
@@ -139,9 +136,9 @@ export const RefinePromptModal = ({
colorScheme="blue" colorScheme="blue"
onClick={replaceVariant} onClick={replaceVariant}
minW={24} minW={24}
isDisabled={replacementInProgress || !refinedPromptFn} isDisabled={!refinedPromptFn}
> >
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>} Accept
</Button> </Button>
</HStack> </HStack>
</ModalFooter> </ModalFooter>

View File

@@ -98,9 +98,7 @@ export const NewExperimentCard = () => {
> >
<VStack align="center" justify="center" w="full" h="full" p={4} onClick={createExperiment}> <VStack align="center" justify="center" w="full" h="full" p={4} onClick={createExperiment}>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} /> <Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
<Text display={{ base: "none", md: "block" }} ml={2}> <Text ml={2}>New Experiment</Text>
New Experiment
</Text>
</VStack> </VStack>
</Card> </Card>
); );

View File

@@ -298,6 +298,7 @@ export const promptVariantsRouter = createTRPCRouter({
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
originalPromptFn: z.string(),
instructions: z.string().optional(), instructions: z.string().optional(),
newModel: z newModel: z
.object({ .object({
@@ -315,22 +316,21 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
await requireCanModifyExperiment(existing.experimentId, ctx); await requireCanModifyExperiment(existing.experimentId, ctx);
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
if ("error" in constructedPrompt) {
return error(constructedPrompt.error);
}
const model = input.newModel const model = input.newModel
? modelProviders[input.newModel.provider].models[input.newModel.model] ? modelProviders[input.newModel.provider].models[input.newModel.model]
: undefined; : undefined;
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions); const promptConstructionFn = await deriveNewConstructFn(
existing,
input.originalPromptFn,
model,
input.instructions,
);
// TODO: Validate promptConstructionFn // TODO: Validate promptConstructionFn
// TODO: Record in some sort of history // TODO: Record in some sort of history
return promptConstructionFn; return success(promptConstructionFn);
}), }),
replaceVariant: protectedProcedure replaceVariant: protectedProcedure

View File

@@ -12,14 +12,20 @@ const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn( export async function deriveNewConstructFn(
originalVariant: PromptVariant | null, originalVariant: PromptVariant | null,
originalPromptFn?: string,
newModel?: Model, newModel?: Model,
instructions?: string, instructions?: string,
) { ) {
if (originalVariant && !newModel && !instructions) { if (originalPromptFn && !newModel && !instructions) {
return originalVariant.promptConstructor; return originalPromptFn;
} }
if (originalVariant && (newModel || instructions)) { if (originalVariant && originalPromptFn && (newModel || instructions)) {
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions); return await requestUpdatedPromptFunction(
originalVariant,
originalPromptFn,
newModel,
instructions,
);
} }
return dedent` return dedent`
prompt = { prompt = {
@@ -36,6 +42,7 @@ export async function deriveNewConstructFn(
const NUM_RETRIES = 5; const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async ( const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant, originalVariant: PromptVariant,
originalPromptFn: string,
newModel?: Model, newModel?: Model,
instructions?: string, instructions?: string,
) => { ) => {
@@ -55,7 +62,7 @@ const requestUpdatedPromptFunction = async (
}, },
{ {
role: "user", role: "user",
content: `This is the current prompt constructor function:\n---\n${originalVariant.promptConstructor}`, content: `This is the current prompt constructor function:\n---\n${originalPromptFn}`,
}, },
]; ];
if (newModel) { if (newModel) {

View File

@@ -1,16 +1,26 @@
import loader, { type Monaco } from "@monaco-editor/loader";
import { type RouterOutputs } from "~/utils/api"; import { type RouterOutputs } from "~/utils/api";
import { type SliceCreator } from "./store"; import { type SliceCreator } from "./store";
import loader from "@monaco-editor/loader";
import formatPromptConstructor from "~/promptConstructor/format"; import formatPromptConstructor from "~/promptConstructor/format";
export const editorBackground = "#fafafa"; export const editorBackground = "#fafafa";
export type CreatedEditor = ReturnType<Monaco["editor"]["create"]>;
type EditorOptions = {
getContent: () => string;
setContent: (content: string) => void;
};
export type SharedVariantEditorSlice = { export type SharedVariantEditorSlice = {
monaco: null | ReturnType<typeof loader.__getMonacoInstance>; monaco: null | Monaco;
loadMonaco: () => Promise<void>; loadMonaco: () => Promise<void>;
scenarioVars: RouterOutputs["scenarioVars"]["list"]; scenarioVars: RouterOutputs["scenarioVars"]["list"];
updateScenariosModel: () => void; updateScenariosModel: () => void;
setScenarioVars: (scenarioVars: RouterOutputs["scenarioVars"]["list"]) => void; setScenarioVars: (scenarioVars: RouterOutputs["scenarioVars"]["list"]) => void;
editorOptionsMap: Record<string, EditorOptions>;
updateOptionsForEditor: (uiId: string, { getContent, setContent }: EditorOptions) => void;
}; };
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
@@ -93,4 +103,10 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
); );
} }
}, },
editorOptionsMap: {},
updateOptionsForEditor: (uiId, options) => {
set((state) => {
state.sharedVariantEditor.editorOptionsMap[uiId] = options;
});
},
}); });