Better division of labor between frontend and backend model providers

A bit better thinking on which types go where.
This commit is contained in:
Kyle Corbitt
2023-07-21 11:49:35 -07:00
parent 7e1fbb3767
commit 741128e0f4
7 changed files with 58 additions and 48 deletions

View File

@@ -10,7 +10,7 @@ import useSocket from "~/utils/useSocket";
import { OutputStats } from "./OutputStats"; import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler"; import { ErrorHandler } from "./ErrorHandler";
import { CellOptions } from "./CellOptions"; import { CellOptions } from "./CellOptions";
import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend"; import frontendModelProviders from "~/modelProviders/frontendModelProviders";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -40,7 +40,7 @@ export default function OutputCell({
); );
const provider = const provider =
modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend]; frontendModelProviders[variant.modelProvider as keyof typeof frontendModelProviders];
type OutputSchema = Parameters<typeof provider.normalizeOutput>[0]; type OutputSchema = Parameters<typeof provider.normalizeOutput>[0];

View File

@@ -1,15 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend"; import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend"; import replicateLlama2Frontend from "./replicate-llama2/frontend";
import { type SupportedProvider, type ModelProviderFrontend } from "./types"; import { type SupportedProvider, type FrontendModelProvider } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here // TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't // Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some // just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server. // transient dependencies that can only be imported on the server.
const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = { const frontendModelProviders: Record<SupportedProvider, FrontendModelProvider<any, any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend, "openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend, "replicate/llama2": replicateLlama2Frontend,
}; };
export default modelProvidersFrontend; export default frontendModelProviders;

View File

@@ -1,8 +1,30 @@
import { type JsonValue } from "type-fest"; import { type JsonValue } from "type-fest";
import { type OpenaiChatModelProvider } from "."; import { type SupportedModel } from ".";
import { type ModelProviderFrontend } from "../types"; import { type FrontendModelProvider } from "../types";
import { type ChatCompletion } from "openai/resources/chat";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ChatCompletion> = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
normalizeOutput: (output) => { normalizeOutput: (output) => {
const message = output.choices[0]?.message; const message = output.choices[0]?.message;
if (!message) if (!message)
@@ -39,4 +61,4 @@ const modelProviderFrontend: ModelProviderFrontend<OpenaiChatModelProvider> = {
}, },
}; };
export default modelProviderFrontend; export default frontendModelProvider;

View File

@@ -3,6 +3,7 @@ import { type ModelProvider } from "../types";
import inputSchema from "./codegen/input.schema.json"; import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
const supportedModels = [ const supportedModels = [
"gpt-4-0613", "gpt-4-0613",
@@ -11,7 +12,7 @@ const supportedModels = [
"gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k-0613",
] as const; ] as const;
type SupportedModel = (typeof supportedModels)[number]; export type SupportedModel = (typeof supportedModels)[number];
export type OpenaiChatModelProvider = ModelProvider< export type OpenaiChatModelProvider = ModelProvider<
SupportedModel, SupportedModel,
@@ -20,25 +21,6 @@ export type OpenaiChatModelProvider = ModelProvider<
>; >;
const modelProvider: OpenaiChatModelProvider = { const modelProvider: OpenaiChatModelProvider = {
name: "OpenAI ChatCompletion",
models: {
"gpt-4-0613": {
name: "GPT-4",
learnMore: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
name: "GPT-4 32k",
learnMore: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo-0613": {
name: "GPT-3.5 Turbo",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
name: "GPT-3.5 Turbo 16k",
learnMore: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
},
getModel: (input) => { getModel: (input) => {
if (supportedModels.includes(input.model as SupportedModel)) if (supportedModels.includes(input.model as SupportedModel))
return input.model as SupportedModel; return input.model as SupportedModel;
@@ -57,6 +39,7 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false, shouldStream: (input) => input.stream ?? false,
getCompletion, getCompletion,
...frontendModelProvider,
}; };
export default modelProvider; export default modelProvider;

View File

@@ -1,7 +1,15 @@
import { type ReplicateLlama2Provider } from "."; import { type SupportedModel, type ReplicateLlama2Output } from ".";
import { type ModelProviderFrontend } from "../types"; import { type FrontendModelProvider } from "../types";
const frontendModelProvider: FrontendModelProvider<SupportedModel, ReplicateLlama2Output> = {
name: "Replicate Llama2",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
normalizeOutput: (output) => { normalizeOutput: (output) => {
return { return {
type: "text", type: "text",
@@ -10,4 +18,4 @@ const modelProviderFrontend: ModelProviderFrontend<ReplicateLlama2Provider> = {
}, },
}; };
export default modelProviderFrontend; export default frontendModelProvider;

View File

@@ -1,9 +1,10 @@
import { type ModelProvider } from "../types"; import { type ModelProvider } from "../types";
import frontendModelProvider from "./frontend";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const; const supportedModels = ["7b-chat", "13b-chat", "70b-chat"] as const;
type SupportedModel = (typeof supportedModels)[number]; export type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = { export type ReplicateLlama2Input = {
model: SupportedModel; model: SupportedModel;
@@ -25,12 +26,6 @@ export type ReplicateLlama2Provider = ModelProvider<
>; >;
const modelProvider: ReplicateLlama2Provider = { const modelProvider: ReplicateLlama2Provider = {
name: "OpenAI ChatCompletion",
models: {
"7b-chat": {},
"13b-chat": {},
"70b-chat": {},
},
getModel: (input) => { getModel: (input) => {
if (supportedModels.includes(input.model)) return input.model; if (supportedModels.includes(input.model)) return input.model;
@@ -69,6 +64,7 @@ const modelProvider: ReplicateLlama2Provider = {
}, },
shouldStream: (input) => input.stream ?? false, shouldStream: (input) => input.stream ?? false,
getCompletion, getCompletion,
...frontendModelProvider,
}; };
export default modelProvider; export default modelProvider;

View File

@@ -3,11 +3,18 @@ import { type JsonValue } from "type-fest";
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2"; export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
type ModelProviderModel = { type ModelInfo = {
name?: string; name?: string;
learnMore?: string; learnMore?: string;
}; };
export type FrontendModelProvider<SupportedModels extends string, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelInfo>;
normalizeOutput: (output: OutputSchema) => NormalizedOutput;
};
export type CompletionResponse<T> = export type CompletionResponse<T> =
| { type: "error"; message: string; autoRetry: boolean; statusCode?: number } | { type: "error"; message: string; autoRetry: boolean; statusCode?: number }
| { | {
@@ -21,8 +28,6 @@ export type CompletionResponse<T> =
}; };
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = { export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
name: string;
models: Record<SupportedModels, ModelProviderModel>;
getModel: (input: InputSchema) => SupportedModels | null; getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean; shouldStream: (input: InputSchema) => boolean;
inputSchema: JSONSchema4; inputSchema: JSONSchema4;
@@ -33,7 +38,7 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
// This is just a convenience for type inference, don't use it at runtime // This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null; _outputSchema?: OutputSchema | null;
}; } & FrontendModelProvider<SupportedModels, OutputSchema>;
export type NormalizedOutput = export type NormalizedOutput =
| { | {
@@ -44,7 +49,3 @@ export type NormalizedOutput =
type: "json"; type: "json";
value: JsonValue; value: JsonValue;
}; };
export type ModelProviderFrontend<ModelProviderT extends ModelProvider<any, any, any>> = {
normalizeOutput: (output: NonNullable<ModelProviderT["_outputSchema"]>) => NormalizedOutput;
};