From 9952dd93d87c53c4c972cbbc1b06bd1f96cc5a2a Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Mon, 24 Jul 2023 17:20:38 -0700 Subject: [PATCH] Only pass in model and provider I got somewhat confused by the extra fields, sorry. Also makes some frontend changes to track that state more directly although in retrospect not sure the frontend changes make things any better. --- .../ChangeModelModal/ChangeModelModal.tsx | 59 ++++++++++--------- .../ChangeModelModal/ModelSearch.tsx | 52 ++++++---------- .../ChangeModelModal/ModelStatsCard.tsx | 21 ++++--- src/modelProviders/types.ts | 26 ++++---- .../api/routers/promptVariants.router.ts | 23 +++++--- src/utils/utils.ts | 11 +++- 6 files changed, 101 insertions(+), 91 deletions(-) diff --git a/src/components/ChangeModelModal/ChangeModelModal.tsx b/src/components/ChangeModelModal/ChangeModelModal.tsx index 47e57c4..15ff507 100644 --- a/src/components/ChangeModelModal/ChangeModelModal.tsx +++ b/src/components/ChangeModelModal/ChangeModelModal.tsx @@ -1,5 +1,7 @@ import { Button, + HStack, + Icon, Modal, ModalBody, ModalCloseButton, @@ -7,24 +9,21 @@ import { ModalFooter, ModalHeader, ModalOverlay, - VStack, - Text, Spinner, - HStack, - Icon, + Text, + VStack, } from "@chakra-ui/react"; -import { RiExchangeFundsFill } from "react-icons/ri"; -import { useState } from "react"; -import { ModelStatsCard } from "./ModelStatsCard"; -import { ModelSearch } from "./ModelSearch"; -import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import { type PromptVariant } from "@prisma/client"; import { isObject, isString } from "lodash-es"; -import { type Model, type SupportedProvider } from "~/modelProviders/types"; -import frontendModelProviders from "~/modelProviders/frontendModelProviders"; -import { keyForModel } from "~/utils/utils"; +import { useState } from "react"; +import { RiExchangeFundsFill } from "react-icons/ri"; +import { type ProviderModel } from "~/modelProviders/types"; +import { api } from "~/utils/api"; +import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { lookupModel, modelLabel } from "~/utils/utils"; +import CompareFunctions from "../RefinePromptModal/CompareFunctions"; +import { ModelSearch } from "./ModelSearch"; +import { ModelStatsCard } from "./ModelStatsCard"; export const ChangeModelModal = ({ variant, @@ -33,11 +32,13 @@ export const ChangeModelModal = ({ variant: PromptVariant; onClose: () => void; }) => { - const originalModelProviderName = variant.modelProvider as SupportedProvider; - const originalModelProvider = frontendModelProviders[originalModelProviderName]; - const originalModel = originalModelProvider.models[variant.model] as Model; - const [selectedModel, setSelectedModel] = useState(originalModel); - const [convertedModel, setConvertedModel] = useState(undefined); + const originalModel = lookupModel(variant.modelProvider, variant.model); + const [selectedModel, setSelectedModel] = useState({ + provider: variant.modelProvider, + model: variant.model, + } as ProviderModel); + const [convertedModel, setConvertedModel] = useState(); + const utils = api.useContext(); const experiment = useExperiment(); @@ -72,9 +73,10 @@ export const ChangeModelModal = ({ onClose(); }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); - const originalModelLabel = keyForModel(originalModel); - const selectedModelLabel = keyForModel(selectedModel); - const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined; + const originalLabel = modelLabel(variant.modelProvider, variant.model); + const selectedLabel = modelLabel(selectedModel.provider, selectedModel.model); + const convertedLabel = + convertedModel && modelLabel(convertedModel.provider, convertedModel.model); return ( - {originalModelLabel !== selectedModelLabel && ( - + {originalLabel !== selectedLabel && ( + )} {isString(modifiedPromptFn) && ( )} @@ -115,7 +120,7 @@ export const ChangeModelModal = ({ colorScheme="gray" onClick={getModifiedPromptFn} minW={24} - isDisabled={originalModel === selectedModel || modificationInProgress} + isDisabled={originalLabel === selectedLabel || modificationInProgress} > {modificationInProgress ? : Convert} diff --git a/src/components/ChangeModelModal/ModelSearch.tsx b/src/components/ChangeModelModal/ModelSearch.tsx index addfd5f..8f03de4 100644 --- a/src/components/ChangeModelModal/ModelSearch.tsx +++ b/src/components/ChangeModelModal/ModelSearch.tsx @@ -1,49 +1,35 @@ -import { VStack, Text } from "@chakra-ui/react"; -import { type LegacyRef, useCallback } from "react"; -import Select, { type SingleValue } from "react-select"; +import { Text, VStack } from "@chakra-ui/react"; +import { type LegacyRef } from "react"; +import Select from "react-select"; import { useElementDimensions } from "~/utils/hooks"; +import { flatMap } from "lodash-es"; import frontendModelProviders from "~/modelProviders/frontendModelProviders"; -import { type Model } from "~/modelProviders/types"; -import { keyForModel } from "~/utils/utils"; +import { type ProviderModel } from "~/modelProviders/types"; +import { modelLabel } from "~/utils/utils"; -const modelOptions: { label: string; value: Model }[] = []; +const modelOptions = flatMap(Object.entries(frontendModelProviders), ([providerId, provider]) => + Object.entries(provider.models).map(([modelId]) => ({ + provider: providerId, + model: modelId, + })), +) as ProviderModel[]; -for (const [_, providerValue] of Object.entries(frontendModelProviders)) { - for (const [_, modelValue] of Object.entries(providerValue.models)) { - modelOptions.push({ - label: keyForModel(modelValue), - value: modelValue, - }); - } -} - -export const ModelSearch = ({ - selectedModel, - setSelectedModel, -}: { - selectedModel: Model; - setSelectedModel: (model: Model) => void; +export const ModelSearch = (props: { + selectedModel: ProviderModel; + setSelectedModel: (model: ProviderModel) => void; }) => { - const handleSelection = useCallback( - (option: SingleValue<{ label: string; value: Model }>) => { - if (!option) return; - setSelectedModel(option.value); - }, - [setSelectedModel], - ); - const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel)); - const [containerRef, containerDimensions] = useElementDimensions(); return ( } w="full"> Browse Models -