Better division of labor between frontend and backend model providers

A bit more careful thinking about which types go where.
Kyle Corbitt
2023-07-21 11:49:35 -07:00
parent 7e1fbb3767
commit 741128e0f4
7 changed files with 58 additions and 48 deletions
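In short: everything the browser needs from a provider (display name, per-model metadata, output normalization) moves into a standalone FrontendModelProvider, while the full ModelProvider keeps the server-only pieces and now extends the frontend type. A condensed sketch of the resulting shape, simplified from the diffs below (not verbatim):

// Frontend-safe half: no server-only imports, fine to ship to the browser.
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
  name: string;
  models: Record<SupportedModels, ModelInfo>;
  normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};

// Full provider: server-only capabilities layered on top via intersection.
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
  getModel: (input: InputSchema) => SupportedModels | null;
  // ...other server-only fields (inputSchema, shouldStream, getCompletion) elided
} & FrontendModelProvider<SupportedModels, OutputSchema>;

// Each provider defines the frontend half once and spreads it into the full
// provider object, so the two halves cannot drift apart:
//   const modelProvider: SomeProvider = { getModel, getCompletion, ...frontendModelProvider };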

View File

@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions";
-import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend";
+import frontendModelProviders from "~/modelProviders/frontendModelProviders";

export default function OutputCell({
  scenario,
@@ -40,7 +40,7 @@ export default function OutputCell({
  );
  const provider =
-    modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend];
+    frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
  type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];

View File

@@ -1,15 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
-import { type SupportedProvider, type ModelProviderFrontend } from "./types";
+import { type SupportedProvider, type FrontendModelProvider } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transitive dependencies that can only be imported on the server.
-const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
+const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
  "openai/ChatCompletion": openaiChatCompletionFrontend,
  "replicate/llama2": replicateLlama2Frontend,
};

-export default modelProvidersFrontend;
+export default frontendModelProviders;
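On the TODO above: as long as this object keeps its Record<SupportedProvider, ...> annotation, forgetting a provider should already be a compile-time error, because Record requires every key of the union to be present. A quick illustration (hypothetical code, not part of the diff):

// Omitting a registered provider fails to type-check:
const incomplete: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
  "openai/ChatCompletion": openaiChatCompletionFrontend,
  // compile error: property "replicate/llama2" is missing in this object literal
};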

View File

@@ -1,8 +1,30 @@
import { type JsonValue } from "type-fest";
-import { type OpenaiChatModelProvider } from ".";
-import { type ModelProviderFrontend } from "../types";
+import { type SupportedModel } from ".";
+import { type FrontendModelProvider } from "../types";
+import { type ChatCompletion } from "openai/resources/chat";

+const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
+  name: "OpenAI ChatCompletion",
+  models: {
+    "gpt-4-0613": {
+      name: "GPT-4",
+      learnMore: "https://openai.com/gpt-4",
+    },
+    "gpt-4-32k-0613": {
+      name: "GPT-4 32k",
+      learnMore: "https://openai.com/gpt-4",
+    },
+    "gpt-3.5-turbo-0613": {
+      name: "GPT-3.5 Turbo",
+      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+    },
+    "gpt-3.5-turbo-16k-0613": {
+      name: "GPT-3.5 Turbo 16k",
+      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
+    },
+  },
-const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
  normalizeOutput: (output) => {
    const message = output.choices[0]?.message;
    if (!message)
@@ -39,4 +61,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
  },
};

-export default modelProviderFrontend;
+export default frontendModelProvider;

View File

@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
+import frontendModelProvider from "./frontend";

const supportedModels = [
  "gpt-4-0613",
@@ -11,7 +12,7 @@ const supportedModels = [
"gpt-3.5-turbo-16k-0613",
] as const;
type SupportedModel = (typeof supportedModels)[number];
export type SupportedModel = (typeof supportedModels)[number];
export type OpenaiChatModelProvider = ModelProvider<
SupportedModel,
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
>;

const modelProvider: OpenaiChatModelProvider = {
-  name: "OpenAI ChatCompletion",
-  models: {
-    "gpt-4-0613": {
-      name: "GPT-4",
-      learnMore: "https://openai.com/gpt-4",
-    },
-    "gpt-4-32k-0613": {
-      name: "GPT-4 32k",
-      learnMore: "https://openai.com/gpt-4",
-    },
-    "gpt-3.5-turbo-0613": {
-      name: "GPT-3.5 Turbo",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-    },
-    "gpt-3.5-turbo-16k-0613": {
-      name: "GPT-3.5 Turbo 16k",
-      learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
-    },
-  },
  getModel: (input) => {
    if (supportedModels.includes(input.model as SupportedModel))
      return input.model as SupportedModel;
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
  inputSchema: inputSchema as JSONSchema4,
  shouldStream: (input) => input.stream ?? false,
  getCompletion,
+  ...frontendModelProvider,
};

export default modelProvider;
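Note the import direction this hunk establishes: index.ts (server-only, since getCompletion pulls in server dependencies) imports frontend.ts and spreads it in, never the reverse. frontend.ts only reaches back with a type-only import, which is erased at compile time, so no server code should leak into the browser bundle. A minimal sketch of the convention (assuming the build elides type-only imports, as TypeScript does by default):

// frontend.ts — browser-safe
import { type SupportedModel } from "."; // type-only: erased after compilation

// index.ts — server-only
import frontendModelProvider from "./frontend"; // runtime import is fine here
const modelProvider = { getCompletion, ...frontendModelProvider };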

View File

@@ -1,7 +1,15 @@
-import { type ReplicateLlama2Provider } from ".";
-import { type ModelProviderFrontend } from "../types";
+import { type SupportedModel, type ReplicateLlama2Output } from ".";
+import { type FrontendModelProvider } from "../types";

+const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
+  name: "Replicate Llama2",
+  models: {
+    "7b-chat": {},
+    "13b-chat": {},
+    "70b-chat": {},
+  },
-const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
  normalizeOutput: (output) => {
    return {
      type: "text",
@@ -10,4 +18,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
  },
};

-export default modelProviderFrontend;
+export default frontendModelProvider;

View File

@@ -1,9 +1,10 @@
import { type ModelProvider } from "../types";
+import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion";

const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;

-type SupportedModel = (typeof supportedModels)[number];
+export type SupportedModel = (typeof supportedModels)[number];

export type ReplicateLlama2Input = {
  model: SupportedModel;
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
>;

const modelProvider: ReplicateLlama2Provider = {
-  name: "OpenAI ChatCompletion",
-  models: {
-    "7b-chat": {},
-    "13b-chat": {},
-    "70b-chat": {},
-  },
  getModel: (input) => {
    if (supportedModels.includes(input.model)) return input.model;
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
  },
  shouldStream: (input) => input.stream ?? false,
  getCompletion,
+  ...frontendModelProvider,
};

export default modelProvider;

View File

@@ -3,11 +3,18 @@ import { type JsonValue } from "type-fest";
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
-type ModelProviderModel = {
+type ModelInfo = {
  name?: string;
  learnMore?: string;
};

+export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
+  name: string;
+  models: Record<SupportedModels, ModelInfo>;
+  normalizeOutput: (output: OutputSchema) => NormalizedOutput;
+};

export type CompletionResponse<T> =
  | { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
  | {
@@ -21,8 +28,6 @@ export type CompletionResponse<T> =
};

export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
-  name: string;
-  models: Record<SupportedModels, ModelProviderModel>;
  getModel: (input: InputSchema) => SupportedModels | null;
  shouldStream: (input: InputSchema) => boolean;
  inputSchema: JSONSchema4;
@@ -33,7 +38,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
  // This is just a convenience for type inference, don't use it at runtime
  _outputSchema?: OutputSchema | null;
-};
+} & FrontendModelProvider<SupportedModels, OutputSchema>;

export type NormalizedOutput =
  | {
@@ -44,7 +49,3 @@ export type NormalizedOutput =
type: "json";
value: JsonValue;
};
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
};
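For context, normalizeOutput funnels every provider's raw output into this single discriminated union, so frontend components can render results without provider-specific branches. A hypothetical consumer (the "text" variant's fields are elided in the hunk above; the sketch assumes it carries a string value, as the Replicate normalizeOutput suggests):

function renderOutput(output: NormalizedOutput): string {
  switch (output.type) {
    case "json":
      // pretty-print structured output
      return JSON.stringify(output.value, null, 2);
    case "text":
      // assumption: the "text" variant has a string `value` (elided above)
      return output.value;
  }
}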