Slightly better typings for ModelProviders

Still not great, since the `any`s loosen some call sites more than I'd like, but better than the broken types before.
Kyle Corbitt
2023-07-21 06:50:05 -07:00
parent a5d972005e
commit 7e1fbb3767
6 changed files with 12 additions and 16 deletions
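For context, a minimal self-contained sketch of the typing pattern this commit moves to. The provider input/output shapes below are invented for illustration; the real `ModelProvider` definition lives in `types.ts`. The old `as const` map kept each provider's precise type, but indexing it with a runtime `SupportedProvider` key produced a union that TypeScript refuses to call, hence the scattered `@ts-expect-error`s. Annotating the map as `Record<SupportedProvider, ModelProvider<any, any, any>>` makes every lookup uniformly callable, at the cost of `any`-typed arguments:

// Sketch only -- simplified stand-ins for the real types in types.ts.
type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";

interface ModelProvider<TInput, TOutput, TModel> {
  getModel(input: TInput): TModel | null;
  shouldStream(input: TInput): boolean;
  normalizeOutput(output: TOutput): { type: "text" | "json" };
}

declare const openaiChatCompletion: ModelProvider<
  { messages: unknown[] },
  { choices: unknown[] },
  string
>;
declare const replicateLlama2: ModelProvider<{ prompt: string }, string, string>;

// Before (broken at call sites): `as const` preserved each entry's narrow
// type, so modelProviders[key] was a union whose methods take an
// intersection of the two input types -- effectively uncallable.
// After: every entry is ModelProvider<any, any, any>, so calls type-check,
// but the compiler no longer validates the arguments.
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
  "openai/ChatCompletion": openaiChatCompletion,
  "replicate/llama2": replicateLlama2,
};

// Call sites like this now compile without @ts-expect-error:
declare const key: SupportedProvider;
const provider = modelProviders[key];
provider.shouldStream({ anything: "goes" }); // accepted -- input is `any`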


@@ -88,11 +88,9 @@ export default function OutputCell({
   }
 
   const normalizedOutput = modelOutput
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? provider.normalizeOutput(modelOutput.output)
     : streamedMessage
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(streamedMessage)
+    ? provider.normalizeOutput(streamedMessage)
     : null;
 
   if (modelOutput && normalizedOutput?.type === "json") {


@@ -1,9 +1,10 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
 import replicateLlama2 from "./replicate-llama2";
+import { type SupportedProvider, type ModelProvider } from "./types";
 
-const modelProviders = {
+const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
   "openai/ChatCompletion": openaiChatCompletion,
   "replicate/llama2": replicateLlama2,
-} as const;
+};
 
 export default modelProviders;
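An aside, not part of this commit: on TypeScript 4.9+ the `satisfies` operator would keep the exhaustiveness check without widening every entry to `any`:

// Sketch of an alternative -- not what this commit does. `satisfies`
// verifies that every SupportedProvider key is present, while the
// variable keeps its narrow inferred type, so direct lookups like
// modelProviders["openai/ChatCompletion"] stay precisely typed.
const modelProviders = {
  "openai/ChatCompletion": openaiChatCompletion,
  "replicate/llama2": replicateLlama2,
} satisfies Record<SupportedProvider, ModelProvider<any, any, any>>;

Lookups keyed by a runtime `SupportedProvider` value would still produce a union, though, so the call-site problem the `any`s paper over would come back.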


@@ -1,14 +1,15 @@
 import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
 import replicateLlama2Frontend from "./replicate-llama2/frontend";
+import { type SupportedProvider, type ModelProviderFrontend } from "./types";
 
 // TODO: make sure we get a typescript error if you forget to add a provider here
 
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
-const modelProvidersFrontend = {
+const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
   "openai/ChatCompletion": openaiChatCompletionFrontend,
   "replicate/llama2": replicateLlama2Frontend,
-} as const;
+};
 
 export default modelProvidersFrontend;
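Note that the `Record<SupportedProvider, ...>` annotation also resolves the TODO above: omitting a provider key is now a compile error. A quick illustration, with a hypothetical omission:

// Compile error -- property '"replicate/llama2"' is missing in the
// object literal but required by the Record type.
const incomplete: Record<SupportedProvider, ModelProviderFrontend<any>> = {
  "openai/ChatCompletion": openaiChatCompletionFrontend,
};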


@@ -1,6 +1,8 @@
 import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";
 
+export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
+
 type ModelProviderModel = {
   name?: string;
   learnMore?: string;


@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
   const provider = modelProviders[prompt.modelProvider];
 
-  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
 
   if (streamingChannel) {

@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     : null;
 
   for (let i = 0; true; i++) {
-    // @ts-expect-error TODO FIX ASAP
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);

@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
       data: {
         scenarioVariantCellId,
         inputHash,
-        output: response.value as unknown as Prisma.InputJsonObject,
+        output: response.value as Prisma.InputJsonObject,
         timeToComplete: response.timeToComplete,
         promptTokens: response.promptTokens,
         completionTokens: response.completionTokens,

@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
         errorMessage: response.message,
         statusCode: response.statusCode,
         retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
-        retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
+        retrievalStatus: "ERROR",
       },
     });


@@ -70,7 +70,6 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
 
-  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {

@@ -80,8 +79,6 @@ export default async function parseConstructFn(
   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
-    // @ts-expect-error TODO FIX ASAP
     model,
     modelInput: input,
   };
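The `Parameters<(typeof provider)["getModel"]>[0]` cast above derives the input type from the provider's own `getModel` signature rather than naming it directly. A standalone illustration of the same utility-type trick, with an invented signature:

// Parameters<F> is a tuple of F's parameter types; [0] takes the first.
// If getModel's signature changes, the derived type follows automatically.
declare const provider: {
  getModel(input: { model: string; temperature?: number }): object | null;
};

type GetModelInput = Parameters<(typeof provider)["getModel"]>[0];
// => { model: string; temperature?: number }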