Only pass in model and provider

I got somewhat confused by the extra fields, sorry.

Also makes some frontend changes to track that state more directly although in retrospect not sure the frontend changes make things any better.
This commit is contained in:
Kyle Corbitt
2023-07-24 17:20:38 -07:00
parent e0b457c6c5
commit 9952dd93d8
6 changed files with 101 additions and 91 deletions

View File

@@ -1,5 +1,7 @@
import {
Button,
HStack,
Icon,
Modal,
ModalBody,
ModalCloseButton,
@@ -7,24 +9,21 @@ import {
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
Icon,
Text,
VStack,
} from "@chakra-ui/react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { useState } from "react";
import { ModelStatsCard } from "./ModelStatsCard";
import { ModelSearch } from "./ModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { type PromptVariant } from "@prisma/client";
import { isObject, isString } from "lodash-es";
import { type Model, type SupportedProvider } from "~/modelProviders/types";
import frontendModelProviders from "~/modelProviders/frontendModelProviders";
import { keyForModel } from "~/utils/utils";
import { useState } from "react";
import { RiExchangeFundsFill } from "react-icons/ri";
import { type ProviderModel } from "~/modelProviders/types";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { lookupModel, modelLabel } from "~/utils/utils";
import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { ModelSearch } from "./ModelSearch";
import { ModelStatsCard } from "./ModelStatsCard";
export const ChangeModelModal = ({
variant,
@@ -33,11 +32,13 @@ export const ChangeModelModal = ({
variant: PromptVariant;
onClose: () => void;
}) => {
const originalModelProviderName = variant.modelProvider as SupportedProvider;
const originalModelProvider = frontendModelProviders[originalModelProviderName];
const originalModel = originalModelProvider.models[variant.model] as Model;
const [selectedModel, setSelectedModel] = useState<Model>(originalModel);
const [convertedModel, setConvertedModel] = useState<Model | undefined>(undefined);
const originalModel = lookupModel(variant.modelProvider, variant.model);
const [selectedModel, setSelectedModel] = useState({
provider: variant.modelProvider,
model: variant.model,
} as ProviderModel);
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
const utils = api.useContext();
const experiment = useExperiment();
@@ -72,9 +73,10 @@ export const ChangeModelModal = ({
onClose();
}, [replaceVariantMutation, variant, onClose, modifiedPromptFn]);
const originalModelLabel = keyForModel(originalModel);
const selectedModelLabel = keyForModel(selectedModel);
const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined;
const originalLabel = modelLabel(variant.modelProvider, variant.model);
const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model);
const convertedLabel =
convertedModel && modelLabel(convertedModel.provider, convertedModel.model);
return (
<Modal
@@ -94,16 +96,19 @@ export const ChangeModelModal = ({
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModelLabel !== selectedModelLabel && (
<ModelStatsCard label="New Model" model={selectedModel} />
{originalLabel !== selectedLabel && (
<ModelStatsCard
label="New Model"
model={lookupModel(selectedModel.provider, selectedModel.model)}
/>
)}
<ModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
{isString(modifiedPromptFn) && (
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={modifiedPromptFn}
leftTitle={originalModelLabel}
rightTitle={convertedModelLabel}
leftTitle={originalLabel}
rightTitle={convertedLabel}
/>
)}
</VStack>
@@ -115,7 +120,7 @@ export const ChangeModelModal = ({
colorScheme="gray"
onClick={getModifiedPromptFn}
minW={24}
isDisabled={originalModel === selectedModel || modificationInProgress}
isDisabled={originalLabel === selectedLabel || modificationInProgress}
>
{modificationInProgress ? <Spinner boxSize={4} /> : <Text>Convert</Text>}
</Button>