diff --git a/src/components/SelectModelModal/SelectModelModal.tsx b/src/components/ChangeModelModal/ChangeModelModal.tsx similarity index 75% rename from src/components/SelectModelModal/SelectModelModal.tsx rename to src/components/ChangeModelModal/ChangeModelModal.tsx index 1476448..47e57c4 100644 --- a/src/components/SelectModelModal/SelectModelModal.tsx +++ b/src/components/ChangeModelModal/ChangeModelModal.tsx @@ -15,25 +15,29 @@ import { } from "@chakra-ui/react"; import { RiExchangeFundsFill } from "react-icons/ri"; import { useState } from "react"; -import { type SupportedModel } from "~/server/types"; import { ModelStatsCard } from "./ModelStatsCard"; -import { SelectModelSearch } from "./SelectModelSearch"; +import { ModelSearch } from "./ModelSearch"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import { type PromptVariant } from "@prisma/client"; import { isObject, isString } from "lodash-es"; +import { type Model, type SupportedProvider } from "~/modelProviders/types"; +import frontendModelProviders from "~/modelProviders/frontendModelProviders"; +import { keyForModel } from "~/utils/utils"; -export const SelectModelModal = ({ +export const ChangeModelModal = ({ variant, onClose, }: { variant: PromptVariant; onClose: () => void; }) => { - const originalModel = variant.model as SupportedModel; - const [selectedModel, setSelectedModel] = useState(originalModel); - const [convertedModel, setConvertedModel] = useState(undefined); + const originalModelProviderName = variant.modelProvider as SupportedProvider; + const originalModelProvider = frontendModelProviders[originalModelProviderName]; + const originalModel = originalModelProvider.models[variant.model] as Model; + const [selectedModel, setSelectedModel] = useState(originalModel); + const [convertedModel, setConvertedModel] = useState(undefined); const utils = api.useContext(); const experiment = useExperiment(); @@ -68,6 +72,10 @@ export const SelectModelModal = ({ onClose(); }, [replaceVariantMutation, variant, onClose, modifiedPromptFn]); + const originalModelLabel = keyForModel(originalModel); + const selectedModelLabel = keyForModel(selectedModel); + const convertedModelLabel = convertedModel ? keyForModel(convertedModel) : undefined; + return ( - {originalModel !== selectedModel && ( + {originalModelLabel !== selectedModelLabel && ( )} - + {isString(modifiedPromptFn) && ( )} diff --git a/src/components/ChangeModelModal/ModelSearch.tsx b/src/components/ChangeModelModal/ModelSearch.tsx new file mode 100644 index 0000000..addfd5f --- /dev/null +++ b/src/components/ChangeModelModal/ModelSearch.tsx @@ -0,0 +1,50 @@ +import { VStack, Text } from "@chakra-ui/react"; +import { type LegacyRef, useCallback } from "react"; +import Select, { type SingleValue } from "react-select"; +import { useElementDimensions } from "~/utils/hooks"; + +import frontendModelProviders from "~/modelProviders/frontendModelProviders"; +import { type Model } from "~/modelProviders/types"; +import { keyForModel } from "~/utils/utils"; + +const modelOptions: { label: string; value: Model }[] = []; + +for (const [_, providerValue] of Object.entries(frontendModelProviders)) { + for (const [_, modelValue] of Object.entries(providerValue.models)) { + modelOptions.push({ + label: keyForModel(modelValue), + value: modelValue, + }); + } +} + +export const ModelSearch = ({ + selectedModel, + setSelectedModel, +}: { + selectedModel: Model; + setSelectedModel: (model: Model) => void; +}) => { + const handleSelection = useCallback( + (option: SingleValue<{ label: string; value: Model }>) => { + if (!option) return; + setSelectedModel(option.value); + }, + [setSelectedModel], + ); + const selectedOption = modelOptions.find((option) => option.label === keyForModel(selectedModel)); + + const [containerRef, containerDimensions] = useElementDimensions(); + + return ( + } w="full"> + Browse Models + ({ ...provided, width: containerDimensions?.width }) }} - value={selectedOption} - options={modelOptions} - onChange={handleSelection} - /> - - ); -}; diff --git a/src/components/VariantHeader/VariantHeaderMenuButton.tsx b/src/components/VariantHeader/VariantHeaderMenuButton.tsx index 31958de..3ddfee4 100644 --- a/src/components/VariantHeader/VariantHeaderMenuButton.tsx +++ b/src/components/VariantHeader/VariantHeaderMenuButton.tsx @@ -17,7 +17,7 @@ import { FaRegClone } from "react-icons/fa"; import { useState } from "react"; import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal"; import { RiExchangeFundsFill } from "react-icons/ri"; -import { SelectModelModal } from "../SelectModelModal/SelectModelModal"; +import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal"; export default function VariantHeaderMenuButton({ variant, @@ -50,7 +50,7 @@ export default function VariantHeaderMenuButton({ await utils.promptVariants.list.invalidate(); }, [hideMutation, variant.id]); - const [selectModelModalOpen, setSelectModelModalOpen] = useState(false); + const [changeModelModalOpen, setChangeModelModalOpen] = useState(false); const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false); return ( @@ -72,7 +72,7 @@ export default function VariantHeaderMenuButton({ } - onClick={() => setSelectModelModalOpen(true)} + onClick={() => setChangeModelModalOpen(true)} > Change Model @@ -97,8 +97,8 @@ export default function VariantHeaderMenuButton({ )} - {selectModelModalOpen && ( - setSelectModelModalOpen(false)} /> + {changeModelModalOpen && ( + setChangeModelModalOpen(false)} /> )} {refinePromptModalOpen && ( setRefinePromptModalOpen(false)} /> diff --git a/src/modelProviders/modelStats.ts b/src/modelProviders/modelStats.ts deleted file mode 100644 index 0b291c5..0000000 --- a/src/modelProviders/modelStats.ts +++ /dev/null @@ -1,77 +0,0 @@ -import { type SupportedModel } from "../server/types"; - -interface ModelStats { - contextLength: number; - promptTokenPrice: number; - completionTokenPrice: number; - speed: "fast" | "medium" | "slow"; - provider: "OpenAI"; - learnMoreUrl: string; -} - -export const modelStats: Record = { - "gpt-4": { - contextLength: 8192, - promptTokenPrice: 0.00003, - completionTokenPrice: 0.00006, - speed: "medium", - provider: "OpenAI", - learnMoreUrl: "https://openai.com/gpt-4", - }, - "gpt-4-0613": { - contextLength: 8192, - promptTokenPrice: 0.00003, - completionTokenPrice: 0.00006, - speed: "medium", - provider: "OpenAI", - learnMoreUrl: "https://openai.com/gpt-4", - }, - "gpt-4-32k": { - contextLength: 32768, - promptTokenPrice: 0.00006, - completionTokenPrice: 0.00012, - speed: "medium", - provider: "OpenAI", - learnMoreUrl: "https://openai.com/gpt-4", - }, - "gpt-4-32k-0613": { - contextLength: 32768, - promptTokenPrice: 0.00006, - completionTokenPrice: 0.00012, - speed: "medium", - provider: "OpenAI", - learnMoreUrl: "https://openai.com/gpt-4", - }, - "gpt-3.5-turbo": { - contextLength: 4096, - promptTokenPrice: 0.0000015, - completionTokenPrice: 0.000002, - speed: "fast", - provider: "OpenAI", - learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, - "gpt-3.5-turbo-0613": { - contextLength: 4096, - promptTokenPrice: 0.0000015, - completionTokenPrice: 0.000002, - speed: "fast", - provider: "OpenAI", - learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, - "gpt-3.5-turbo-16k": { - contextLength: 16384, - promptTokenPrice: 0.000003, - completionTokenPrice: 0.000004, - speed: "fast", - provider: "OpenAI", - learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, - "gpt-3.5-turbo-16k-0613": { - contextLength: 16384, - promptTokenPrice: 0.000003, - completionTokenPrice: 0.000004, - speed: "fast", - provider: "OpenAI", - learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, -}; diff --git a/src/modelProviders/openai-ChatCompletion/frontend.ts b/src/modelProviders/openai-ChatCompletion/frontend.ts index b4a9fc2..825c2c1 100644 --- a/src/modelProviders/openai-ChatCompletion/frontend.ts +++ b/src/modelProviders/openai-ChatCompletion/frontend.ts @@ -9,19 +9,39 @@ const frontendModelProvider: FrontendModelProvider c.message).filter(truthyFilter), ); } catch (err) { @@ -106,10 +104,10 @@ export async function getCompletion( } const timeToComplete = Date.now() - start; - const stats = modelStats[input.model as keyof typeof OpenAIChatModel]; + const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName]; let cost = undefined; - if (stats && promptTokens && completionTokens) { - cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice; + if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) { + cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice; } return { diff --git a/src/modelProviders/replicate-llama2/frontend.ts b/src/modelProviders/replicate-llama2/frontend.ts index 9c8df44..dc3a6e0 100644 --- a/src/modelProviders/replicate-llama2/frontend.ts +++ b/src/modelProviders/replicate-llama2/frontend.ts @@ -5,9 +5,30 @@ const frontendModelProvider: FrontendModelProvider { diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts index c9023d8..2bdb8ec 100644 --- a/src/modelProviders/types.ts +++ b/src/modelProviders/types.ts @@ -1,16 +1,31 @@ import { type JSONSchema4 } from "json-schema"; import { type JsonValue } from "type-fest"; +import { z } from "zod"; -export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2"; +const ZodSupportedProvider = z.union([ + z.literal("openai/ChatCompletion"), + z.literal("replicate/llama2"), +]); -type ModelInfo = { - name?: string; - learnMore?: string; -}; +export type SupportedProvider = z.infer; + +export const ZodModel = z.object({ + name: z.string(), + contextWindow: z.number(), + promptTokenPrice: z.number().optional(), + completionTokenPrice: z.number().optional(), + pricePerSecond: z.number().optional(), + speed: z.union([z.literal("fast"), z.literal("medium"), z.literal("slow")]), + provider: ZodSupportedProvider, + description: z.string().optional(), + learnMoreUrl: z.string().optional(), +}); + +export type Model = z.infer; export type FrontendModelProvider = { name: string; - models: Record; + models: Record; normalizeOutput: (output: OutputSchema) => NormalizedOutput; }; diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index cef951c..d88a53d 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -2,7 +2,6 @@ import { z } from "zod"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { generateNewCell } from "~/server/utils/generateNewCell"; -import { type SupportedModel } from "~/server/types"; import userError from "~/server/utils/error"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants"; @@ -10,6 +9,7 @@ import { type PromptVariant } from "@prisma/client"; import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn"; import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; import parseConstructFn from "~/server/utils/parseConstructFn"; +import { ZodModel } from "~/modelProviders/types"; export const promptVariantsRouter = createTRPCRouter({ list: publicProcedure @@ -144,7 +144,7 @@ export const promptVariantsRouter = createTRPCRouter({ z.object({ experimentId: z.string(), variantId: z.string().optional(), - newModel: z.string().optional(), + newModel: ZodModel.optional(), }), ) .mutation(async ({ input, ctx }) => { @@ -186,10 +186,7 @@ export const promptVariantsRouter = createTRPCRouter({ ? `${originalVariant?.label} Copy` : `Prompt Variant ${largestSortIndex + 2}`; - const newConstructFn = await deriveNewConstructFn( - originalVariant, - input.newModel as SupportedModel, - ); + const newConstructFn = await deriveNewConstructFn(originalVariant, input.newModel); const createNewVariantAction = prisma.promptVariant.create({ data: { @@ -289,7 +286,7 @@ export const promptVariantsRouter = createTRPCRouter({ z.object({ id: z.string(), instructions: z.string().optional(), - newModel: z.string().optional(), + newModel: ZodModel.optional(), }), ) .mutation(async ({ input, ctx }) => { @@ -308,7 +305,7 @@ export const promptVariantsRouter = createTRPCRouter({ const promptConstructionFn = await deriveNewConstructFn( existing, - input.newModel as SupportedModel | undefined, + input.newModel, input.instructions, ); diff --git a/src/server/types.ts b/src/server/types.ts deleted file mode 100644 index 983bacd..0000000 --- a/src/server/types.ts +++ /dev/null @@ -1,12 +0,0 @@ -export enum OpenAIChatModel { - "gpt-4" = "gpt-4", - "gpt-4-0613" = "gpt-4-0613", - "gpt-4-32k" = "gpt-4-32k", - "gpt-4-32k-0613" = "gpt-4-32k-0613", - "gpt-3.5-turbo" = "gpt-3.5-turbo", - "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613", -} - -export type SupportedModel = keyof typeof OpenAIChatModel; diff --git a/src/server/utils/deriveNewContructFn.ts b/src/server/utils/deriveNewContructFn.ts index eed1398..2149f00 100644 --- a/src/server/utils/deriveNewContructFn.ts +++ b/src/server/utils/deriveNewContructFn.ts @@ -1,18 +1,18 @@ import { type PromptVariant } from "@prisma/client"; -import { type SupportedModel } from "../types"; import ivm from "isolated-vm"; import dedent from "dedent"; import { openai } from "./openai"; -import { getApiShapeForModel } from "./getTypesForModel"; import { isObject } from "lodash-es"; import { type CompletionCreateParams } from "openai/resources/chat/completions"; import formatPromptConstructor from "~/utils/formatPromptConstructor"; +import { type SupportedProvider, type Model } from "~/modelProviders/types"; +import modelProviders from "~/modelProviders/modelProviders"; const isolate = new ivm.Isolate({ memoryLimit: 128 }); export async function deriveNewConstructFn( originalVariant: PromptVariant | null, - newModel?: SupportedModel, + newModel?: Model, instructions?: string, ) { if (originalVariant && !newModel && !instructions) { @@ -36,10 +36,11 @@ export async function deriveNewConstructFn( const NUM_RETRIES = 5; const requestUpdatedPromptFunction = async ( originalVariant: PromptVariant, - newModel?: SupportedModel, + newModel?: Model, instructions?: string, ) => { - const originalModel = originalVariant.model as SupportedModel; + const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider]; + const originalModel = originalModelProvider.models[originalVariant.model] as Model; let newContructionFn = ""; for (let i = 0; i < NUM_RETRIES; i++) { try { @@ -47,7 +48,7 @@ const requestUpdatedPromptFunction = async ( { role: "system", content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify( - getApiShapeForModel(originalModel), + originalModelProvider.inputSchema, null, 2, )}\n\nDo not add any assistant messages.`, @@ -60,8 +61,20 @@ const requestUpdatedPromptFunction = async ( if (newModel) { messages.push({ role: "user", - content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`, + content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`, }); + if (newModel.provider !== originalModel.provider) { + messages.push({ + role: "user", + content: `The old provider was ${originalModel.provider}. The new provider is ${ + newModel.provider + }. Here is the schema for the new model:\n---\n${JSON.stringify( + modelProviders[newModel.provider].inputSchema, + null, + 2, + )}`, + }); + } } if (instructions) { messages.push({ diff --git a/src/server/utils/getTypesForModel.ts b/src/server/utils/getTypesForModel.ts deleted file mode 100644 index 2aa679a..0000000 --- a/src/server/utils/getTypesForModel.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { type SupportedModel } from "../types"; - -export const getApiShapeForModel = (model: SupportedModel) => { - // if (model in OpenAIChatModel) return openAIChatApiShape; - return ""; -}; diff --git a/src/utils/countTokens.ts b/src/utils/countTokens.ts index c21d432..653adaa 100644 --- a/src/utils/countTokens.ts +++ b/src/utils/countTokens.ts @@ -1,6 +1,6 @@ import { type ChatCompletion } from "openai/resources/chat"; import { GPTTokens } from "gpt-tokens"; -import { type OpenAIChatModel } from "~/server/types"; +import { type SupportedModel } from "~/modelProviders/openai-ChatCompletion"; interface GPTTokensMessageItem { name?: string; @@ -9,7 +9,7 @@ interface GPTTokensMessageItem { } export const countOpenAIChatTokens = ( - model: keyof typeof OpenAIChatModel, + model: SupportedModel, messages: ChatCompletion.Choice.Message[], ) => { return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] }) diff --git a/src/utils/utils.ts b/src/utils/utils.ts index 217c0fb..2644dfb 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -1 +1,5 @@ +import { type Model } from "~/modelProviders/types"; + export const truthyFilter = (x: T | null | undefined): x is T => Boolean(x); + +export const keyForModel = (model: Model) => `${model.provider}/${model.name}`;