From 741128e0f48b7a6870ddb53bd4f44e2fcf738414 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Fri, 21 Jul 2023 11:49:35 -0700 Subject: [PATCH] Better division of labor between frontend and backend model providers A bit better thinking on which types go where. --- .../OutputsTable/OutputCell/OutputCell.tsx | 4 +-- ...sFrontend.ts => frontendModelProviders.ts} | 6 ++-- .../openai-ChatCompletion/frontend.ts | 30 ++++++++++++++++--- .../openai-ChatCompletion/index.ts | 23 ++------------ .../replicate-llama2/frontend.ts | 16 +++++++--- src/modelProviders/replicate-llama2/index.ts | 10 ++----- src/modelProviders/types.ts | 17 ++++++----- 7 files changed, 58 insertions(+), 48 deletions(-) rename src/modelProviders/{modelProvidersFrontend.ts => frontendModelProviders.ts} (74%) diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index c8bddce..a127c3e 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket"; import { OutputStats } from "./OutputStats"; import { ErrorHandler } from "./ErrorHandler"; import { CellOptions } from "./CellOptions"; -import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend"; +import frontendModelProviders from "~/modelProviders/frontendModelProviders"; export default function OutputCell({ scenario, @@ -40,7 +40,7 @@ export default function OutputCell({ ); const provider = - modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend]; + frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders]; type OutputSchema = Parameters[0]; diff --git a/src/modelProviders/modelProvidersFrontend.ts b/src/modelProviders/frontendModelProviders.ts similarity index 74% rename from src/modelProviders/modelProvidersFrontend.ts rename to src/modelProviders/frontendModelProviders.ts index 7068836..cdda811 100644 --- a/src/modelProviders/modelProvidersFrontend.ts +++ b/src/modelProviders/frontendModelProviders.ts @@ -1,15 +1,15 @@ import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend"; import replicateLlama2Frontend from "./replicate-llama2/frontend"; -import { type SupportedProvider, type ModelProviderFrontend } from "./types"; +import { type SupportedProvider, type FrontendModelProvider } from "./types"; // TODO: make sure we get a typescript error if you forget to add a provider here // Keep attributes here that need to be accessible from the frontend. We can't // just include them in the default `modelProviders` object because it has some // transient dependencies that can only be imported on the server. -const modelProvidersFrontend: Record> = { +const frontendModelProviders: Record> = { "openai/ChatCompletion": openaiChatCompletionFrontend, "replicate/llama2": replicateLlama2Frontend, }; -export default modelProvidersFrontend; +export default frontendModelProviders; diff --git a/src/modelProviders/openai-ChatCompletion/frontend.ts b/src/modelProviders/openai-ChatCompletion/frontend.ts index ba90c1b..b4a9fc2 100644 --- a/src/modelProviders/openai-ChatCompletion/frontend.ts +++ b/src/modelProviders/openai-ChatCompletion/frontend.ts @@ -1,8 +1,30 @@ import { type JsonValue } from "type-fest"; -import { type OpenaiChatModelProvider } from "."; -import { type ModelProviderFrontend } from "../types"; +import { type SupportedModel } from "."; +import { type FrontendModelProvider } from "../types"; +import { type ChatCompletion } from "openai/resources/chat"; + +const frontendModelProvider: FrontendModelProvider = { + name: "OpenAI ChatCompletion", + + models: { + "gpt-4-0613": { + name: "GPT-4", + learnMore: "https://openai.com/gpt-4", + }, + "gpt-4-32k-0613": { + name: "GPT-4 32k", + learnMore: "https://openai.com/gpt-4", + }, + "gpt-3.5-turbo-0613": { + name: "GPT-3.5 Turbo", + learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", + }, + "gpt-3.5-turbo-16k-0613": { + name: "GPT-3.5 Turbo 16k", + learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", + }, + }, -const modelProviderFrontend: ModelProviderFrontend = { normalizeOutput: (output) => { const message = output.choices[0]?.message; if (!message) @@ -39,4 +61,4 @@ const modelProviderFrontend: ModelProviderFrontend = { }, }; -export default modelProviderFrontend; +export default frontendModelProvider; diff --git a/src/modelProviders/openai-ChatCompletion/index.ts b/src/modelProviders/openai-ChatCompletion/index.ts index 294877b..9aa882c 100644 --- a/src/modelProviders/openai-ChatCompletion/index.ts +++ b/src/modelProviders/openai-ChatCompletion/index.ts @@ -3,6 +3,7 @@ import { type ModelProvider } from "../types"; import inputSchema from "./codegen/input.schema.json"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { getCompletion } from "./getCompletion"; +import frontendModelProvider from "./frontend"; const supportedModels = [ "gpt-4-0613", @@ -11,7 +12,7 @@ const supportedModels = [ "gpt-3.5-turbo-16k-0613", ] as const; -type SupportedModel = (typeof supportedModels)[number]; +export type SupportedModel = (typeof supportedModels)[number]; export type OpenaiChatModelProvider = ModelProvider< SupportedModel, @@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider< >; const modelProvider: OpenaiChatModelProvider = { - name: "OpenAI ChatCompletion", - models: { - "gpt-4-0613": { - name: "GPT-4", - learnMore: "https://openai.com/gpt-4", - }, - "gpt-4-32k-0613": { - name: "GPT-4 32k", - learnMore: "https://openai.com/gpt-4", - }, - "gpt-3.5-turbo-0613": { - name: "GPT-3.5 Turbo", - learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, - "gpt-3.5-turbo-16k-0613": { - name: "GPT-3.5 Turbo 16k", - learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api", - }, - }, getModel: (input) => { if (supportedModels.includes(input.model as SupportedModel)) return input.model as SupportedModel; @@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = { inputSchema: inputSchema as JSONSchema4, shouldStream: (input) => input.stream ?? false, getCompletion, + ...frontendModelProvider, }; export default modelProvider; diff --git a/src/modelProviders/replicate-llama2/frontend.ts b/src/modelProviders/replicate-llama2/frontend.ts index e7f44eb..9c8df44 100644 --- a/src/modelProviders/replicate-llama2/frontend.ts +++ b/src/modelProviders/replicate-llama2/frontend.ts @@ -1,7 +1,15 @@ -import { type ReplicateLlama2Provider } from "."; -import { type ModelProviderFrontend } from "../types"; +import { type SupportedModel, type ReplicateLlama2Output } from "."; +import { type FrontendModelProvider } from "../types"; + +const frontendModelProvider: FrontendModelProvider = { + name: "Replicate Llama2", + + models: { + "7b-chat": {}, + "13b-chat": {}, + "70b-chat": {}, + }, -const modelProviderFrontend: ModelProviderFrontend = { normalizeOutput: (output) => { return { type: "text", @@ -10,4 +18,4 @@ const modelProviderFrontend: ModelProviderFrontend = { }, }; -export default modelProviderFrontend; +export default frontendModelProvider; diff --git a/src/modelProviders/replicate-llama2/index.ts b/src/modelProviders/replicate-llama2/index.ts index 49e1d1e..4e7c891 100644 --- a/src/modelProviders/replicate-llama2/index.ts +++ b/src/modelProviders/replicate-llama2/index.ts @@ -1,9 +1,10 @@ import { type ModelProvider } from "../types"; +import frontendModelProvider from "./frontend"; import { getCompletion } from "./getCompletion"; const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const; -type SupportedModel = (typeof supportedModels)[number]; +export type SupportedModel = (typeof supportedModels)[number]; export type ReplicateLlama2Input = { model: SupportedModel; @@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider< >; const modelProvider: ReplicateLlama2Provider = { - name: "OpenAI ChatCompletion", - models: { - "7b-chat": {}, - "13b-chat": {}, - "70b-chat": {}, - }, getModel: (input) => { if (supportedModels.includes(input.model)) return input.model; @@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = { }, shouldStream: (input) => input.stream ?? false, getCompletion, + ...frontendModelProvider, }; export default modelProvider; diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts index d4eafed..c9023d8 100644 --- a/src/modelProviders/types.ts +++ b/src/modelProviders/types.ts @@ -3,11 +3,18 @@ import { type JsonValue } from "type-fest"; export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2"; -type ModelProviderModel = { +type ModelInfo = { name?: string; learnMore?: string; }; +export type FrontendModelProvider = { + name: string; + models: Record; + + normalizeOutput: (output: OutputSchema) => NormalizedOutput; +}; + export type CompletionResponse = | { type: "error"; message: string; autoRetry: boolean; statusCode?: number } | { @@ -21,8 +28,6 @@ export type CompletionResponse = }; export type ModelProvider = { - name: string; - models: Record; getModel: (input: InputSchema) => SupportedModels | null; shouldStream: (input: InputSchema) => boolean; inputSchema: JSONSchema4; @@ -33,7 +38,7 @@ export type ModelProvider; export type NormalizedOutput = | { @@ -44,7 +49,3 @@ export type NormalizedOutput = type: "json"; value: JsonValue; }; - -export type ModelProviderFrontend> = { - normalizeOutput: (output: NonNullable) => NormalizedOutput; -};