Slightly better typings for ModelProviders
Still not great because the `any`s loosen some call sites up more than I'd like, but better than the broken types before.
This commit is contained in:
@@ -88,11 +88,9 @@ export default function OutputCell({
|
||||
}
|
||||
|
||||
const normalizedOutput = modelOutput
|
||||
? // @ts-expect-error TODO FIX ASAP
|
||||
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
||||
? provider.normalizeOutput(modelOutput.output)
|
||||
: streamedMessage
|
||||
? // @ts-expect-error TODO FIX ASAP
|
||||
provider.normalizeOutput(streamedMessage)
|
||||
? provider.normalizeOutput(streamedMessage)
|
||||
: null;
|
||||
|
||||
if (modelOutput && normalizedOutput?.type === "json") {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||
import replicateLlama2 from "./replicate-llama2";
|
||||
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||
|
||||
const modelProviders = {
|
||||
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletion,
|
||||
"replicate/llama2": replicateLlama2,
|
||||
} as const;
|
||||
};
|
||||
|
||||
export default modelProviders;
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||
import { type SupportedProvider, type ModelProviderFrontend } from "./types";
|
||||
|
||||
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||
|
||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||
// just include them in the default `modelProviders` object because it has some
|
||||
// transient dependencies that can only be imported on the server.
|
||||
const modelProvidersFrontend = {
|
||||
const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
|
||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||
"replicate/llama2": replicateLlama2Frontend,
|
||||
} as const;
|
||||
};
|
||||
|
||||
export default modelProvidersFrontend;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { type JSONSchema4 } from "json-schema";
|
||||
import { type JsonValue } from "type-fest";
|
||||
|
||||
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
|
||||
|
||||
type ModelProviderModel = {
|
||||
name?: string;
|
||||
learnMore?: string;
|
||||
|
||||
@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
|
||||
const provider = modelProviders[prompt.modelProvider];
|
||||
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
||||
|
||||
if (streamingChannel) {
|
||||
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
: null;
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
const inputHash = hashPrompt(prompt);
|
||||
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
data: {
|
||||
scenarioVariantCellId,
|
||||
inputHash,
|
||||
output: response.value as unknown as Prisma.InputJsonObject,
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
timeToComplete: response.timeToComplete,
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
errorMessage: response.message,
|
||||
statusCode: response.statusCode,
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
|
||||
// We've validated the JSON schema so this should be safe
|
||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
const model = provider.getModel(input);
|
||||
if (!model) {
|
||||
return {
|
||||
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
|
||||
|
||||
return {
|
||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||
// @ts-expect-error TODO FIX ASAP
|
||||
|
||||
model,
|
||||
modelInput: input,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user