Slightly better typings for ModelProviders

Still not great, because the `any`s loosen up some call sites more than I'd like, but it's better than the broken types we had before.
This commit is contained in:
Kyle Corbitt
2023-07-21 06:50:05 -07:00
parent a5d972005e
commit 7e1fbb3767
6 changed files with 12 additions and 16 deletions

View File

@@ -88,11 +88,9 @@ export default function OutputCell({
}
const normalizedOutput = modelOutput
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
? provider.normalizeOutput(modelOutput.output)
: streamedMessage
? // @ts-expect-error TODO FIX ASAP
provider.normalizeOutput(streamedMessage)
? provider.normalizeOutput(streamedMessage)
: null;
if (modelOutput && normalizedOutput?.type === "json") {

View File

@@ -1,9 +1,10 @@
import openaiChatCompletion from "./openai-ChatCompletion";
import replicateLlama2 from "./replicate-llama2";
import { type SupportedProvider, type ModelProvider } from "./types";
const modelProviders = {
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
"openai/ChatCompletion": openaiChatCompletion,
"replicate/llama2": replicateLlama2,
} as const;
};
export default modelProviders;

View File

@@ -1,14 +1,15 @@
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
import replicateLlama2Frontend from "./replicate-llama2/frontend";
import { type SupportedProvider, type ModelProviderFrontend } from "./types";
// TODO: make sure we get a typescript error if you forget to add a provider here
// Keep attributes here that need to be accessible from the frontend. We can't
// just include them in the default `modelProviders` object because it has some
// transient dependencies that can only be imported on the server.
const modelProvidersFrontend = {
const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
"openai/ChatCompletion": openaiChatCompletionFrontend,
"replicate/llama2": replicateLlama2Frontend,
} as const;
};
export default modelProvidersFrontend;

View File

@@ -1,6 +1,8 @@
import { type JSONSchema4 } from "json-schema";
import { type JsonValue } from "type-fest";
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
type ModelProviderModel = {
name?: string;
learnMore?: string;

View File

@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider];
// @ts-expect-error TODO FIX ASAP
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
: null;
for (let i = 0; true; i++) {
// @ts-expect-error TODO FIX ASAP
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt);
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
data: {
scenarioVariantCellId,
inputHash,
output: response.value as unknown as Prisma.InputJsonObject,
output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete,
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
errorMessage: response.message,
statusCode: response.statusCode,
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
retrievalStatus: "ERROR",
},
});

View File

@@ -70,7 +70,6 @@ export default async function parseConstructFn(
// We've validated the JSON schema so this should be safe
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
// @ts-expect-error TODO FIX ASAP
const model = provider.getModel(input);
if (!model) {
return {
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
return {
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
// @ts-expect-error TODO FIX ASAP
model,
modelInput: input,
};