diff --git a/src/components/ChangeModelModal/ChangeModelModal.tsx b/src/components/ChangeModelModal/ChangeModelModal.tsx index 47e57c4..15ff507 100644 --- a/src/components/ChangeModelModal/ChangeModelModal.tsx +++ b/src/components/ChangeModelModal/ChangeModelModal.tsx @@ -1,5 +1,7 @@ import { Button, + HStack, + Icon, Modal, ModalBody, ModalCloseButton, @@ -7,24 +9,21 @@ import { ModalFooter, ModalHeader, ModalOverlay, - VStack, - Text, Spinner, - HStack, - Icon, + Text, + VStack, } from "@chakra-ui/react"; -import { RiExchangeFundsFill } from "react-icons/ri"; -import { useState } from "react"; -import { ModelStatsCard } from "./ModelStatsCard"; -import { ModelSearch } from "./ModelSearch"; -import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import { type PromptVariant } from "@prisma/client"; import { isObject, isString } from "lodash-es"; -import { type Model, type SupportedProvider } from "~/modelProviders/types"; -import frontendModelProviders from "~/modelProviders/frontendModelProviders"; -import { keyForModel } from "~/utils/utils"; +import { useState } from "react"; +import { RiExchangeFundsFill } from "react-icons/ri"; +import { type ProviderModel } from "~/modelProviders/types"; +import { api } from "~/utils/api"; +import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { lookupModel, modelLabel } from "~/utils/utils"; +import CompareFunctions from "../RefinePromptModal/CompareFunctions"; +import { ModelSearch } from "./ModelSearch"; +import { ModelStatsCard } from "./ModelStatsCard"; export const ChangeModelModal = ({ variant, @@ -33,11 +32,13 @@ export const ChangeModelModal = ({ variant: PromptVariant; onClose: () => void; }) => { - const originalModelProviderName = variant.modelProvider as SupportedProvider; - const originalModelProvider = frontendModelProviders[originalModelProviderName]; - const originalModel = originalModelProvider.models[variant.model] as Model; - const [selectedModel, setSelectedModel] = useState(originalModel); - const [convertedModel, setConvertedModel] = useState(undefined); + const originalModel = lookupModel(variant.modelProvider, variant.model); + const [selectedModel, setSelectedModel] = useState({ + provider: variant.modelProvider, + model: variant.model, + } as ProviderModel); + const [convertedModel, setConvertedModel] = useState(); + const utils = api.useContext(); const experiment = useExperiment(); @@ -72,9 +73,10 @@ export const ChangeModelModal = ({ onClose(); }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); - const originalModelLabel = keyForModel(originalModel); - const selectedModelLabel = keyForModel(selectedModel); - const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined; + const originalLabel = modelLabel(variant.modelProvider, variant.model); + const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model); + const convertedLabel = + convertedModel && modelLabel(convertedModel.provider, convertedModel.model); return ( - {originalModelLabel !== selectedModelLabel && ( - + {originalLabel !== selectedLabel && ( + )} {isString(modifiedPromptFn) && ( )} @@ -115,7 +120,7 @@ export const ChangeModelModal = ({ colorScheme="gray" onClick={getModifiedPromptFn} minW={24} - isDisabled={originalModel === selectedModel || modificationInProgress} + isDisabled={originalLabel === selectedLabel || modificationInProgress} > {modificationInProgress ? : Convert} diff --git a/src/components/ChangeModelModal/ModelSearch.tsx b/src/components/ChangeModelModal/ModelSearch.tsx index addfd5f..8f03de4 100644 --- a/src/components/ChangeModelModal/ModelSearch.tsx +++ b/src/components/ChangeModelModal/ModelSearch.tsx @@ -1,49 +1,35 @@ -import { VStack, Text } from "@chakra-ui/react"; -import { type LegacyRef, useCallback } from "react"; -import Select, { type SingleValue } from "react-select"; +import { Text, VStack } from "@chakra-ui/react"; +import { type LegacyRef } from "react"; +import Select from "react-select"; import { useElementDimensions } from "~/utils/hooks"; +import { flatMap } from "lodash-es"; import frontendModelProviders from "~/modelProviders/frontendModelProviders"; -import { type Model } from "~/modelProviders/types"; -import { keyForModel } from "~/utils/utils"; +import { type ProviderModel } from "~/modelProviders/types"; +import { modelLabel } from "~/utils/utils"; -const modelOptions: { label: string; value: Model }[] = []; +const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) => + Object.entries(provider.models).map(([modelId]) => ({ + provider: providerId, + model: modelId, + })), +) as ProviderModel[]; -for (const [_, providerValue] of Object.entries(frontendModelProviders)) { - for (const [_, modelValue] of Object.entries(providerValue.models)) { - modelOptions.push({ - label: keyForModel(modelValue), - value: modelValue, - }); - } -} - -export const ModelSearch = ({ - selectedModel, - setSelectedModel, -}: { - selectedModel: Model; - setSelectedModel: (model: Model) => void; +export const ModelSearch = (props: { + selectedModel: ProviderModel; + setSelectedModel: (model: ProviderModel) => void; }) => { - const handleSelection = useCallback( - (option: SingleValue<{ label: string; value: Model }>) => { - if (!option) return; - setSelectedModel(option.value); - }, - [setSelectedModel], - ); - const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel)); - const [containerRef, containerDimensions] = useElementDimensions(); return ( } w="full"> Browse Models -