Slightly better typings for ModelProviders
Still not great because the `any`s loosen some call sites up more than I'd like, but better than the broken types before.
This commit is contained in:
@@ -88,11 +88,9 @@ export default function OutputCell({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const normalizedOutput = modelOutput
|
const normalizedOutput = modelOutput
|
||||||
? // @ts-expect-error TODO FIX ASAP
|
? provider.normalizeOutput(modelOutput.output)
|
||||||
provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
|
|
||||||
: streamedMessage
|
: streamedMessage
|
||||||
? // @ts-expect-error TODO FIX ASAP
|
? provider.normalizeOutput(streamedMessage)
|
||||||
provider.normalizeOutput(streamedMessage)
|
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
if (modelOutput && normalizedOutput?.type === "json") {
|
if (modelOutput && normalizedOutput?.type === "json") {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import openaiChatCompletion from "./openai-ChatCompletion";
|
import openaiChatCompletion from "./openai-ChatCompletion";
|
||||||
import replicateLlama2 from "./replicate-llama2";
|
import replicateLlama2 from "./replicate-llama2";
|
||||||
|
import { type SupportedProvider, type ModelProvider } from "./types";
|
||||||
|
|
||||||
const modelProviders = {
|
const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
|
||||||
"openai/ChatCompletion": openaiChatCompletion,
|
"openai/ChatCompletion": openaiChatCompletion,
|
||||||
"replicate/llama2": replicateLlama2,
|
"replicate/llama2": replicateLlama2,
|
||||||
} as const;
|
};
|
||||||
|
|
||||||
export default modelProviders;
|
export default modelProviders;
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
|
||||||
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
import replicateLlama2Frontend from "./replicate-llama2/frontend";
|
||||||
|
import { type SupportedProvider, type ModelProviderFrontend } from "./types";
|
||||||
|
|
||||||
// TODO: make sure we get a typescript error if you forget to add a provider here
|
// TODO: make sure we get a typescript error if you forget to add a provider here
|
||||||
|
|
||||||
// Keep attributes here that need to be accessible from the frontend. We can't
|
// Keep attributes here that need to be accessible from the frontend. We can't
|
||||||
// just include them in the default `modelProviders` object because it has some
|
// just include them in the default `modelProviders` object because it has some
|
||||||
// transient dependencies that can only be imported on the server.
|
// transient dependencies that can only be imported on the server.
|
||||||
const modelProvidersFrontend = {
|
const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
|
||||||
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
"openai/ChatCompletion": openaiChatCompletionFrontend,
|
||||||
"replicate/llama2": replicateLlama2Frontend,
|
"replicate/llama2": replicateLlama2Frontend,
|
||||||
} as const;
|
};
|
||||||
|
|
||||||
export default modelProvidersFrontend;
|
export default modelProvidersFrontend;
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { type JSONSchema4 } from "json-schema";
|
import { type JSONSchema4 } from "json-schema";
|
||||||
import { type JsonValue } from "type-fest";
|
import { type JsonValue } from "type-fest";
|
||||||
|
|
||||||
|
export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
|
||||||
|
|
||||||
type ModelProviderModel = {
|
type ModelProviderModel = {
|
||||||
name?: string;
|
name?: string;
|
||||||
learnMore?: string;
|
learnMore?: string;
|
||||||
|
|||||||
@@ -99,7 +99,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
|
|
||||||
const provider = modelProviders[prompt.modelProvider];
|
const provider = modelProviders[prompt.modelProvider];
|
||||||
|
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
|
||||||
|
|
||||||
if (streamingChannel) {
|
if (streamingChannel) {
|
||||||
@@ -116,8 +115,6 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
: null;
|
: null;
|
||||||
|
|
||||||
for (let i = 0; true; i++) {
|
for (let i = 0; true; i++) {
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
|
|
||||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||||
if (response.type === "success") {
|
if (response.type === "success") {
|
||||||
const inputHash = hashPrompt(prompt);
|
const inputHash = hashPrompt(prompt);
|
||||||
@@ -126,7 +123,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
data: {
|
data: {
|
||||||
scenarioVariantCellId,
|
scenarioVariantCellId,
|
||||||
inputHash,
|
inputHash,
|
||||||
output: response.value as unknown as Prisma.InputJsonObject,
|
output: response.value as Prisma.InputJsonObject,
|
||||||
timeToComplete: response.timeToComplete,
|
timeToComplete: response.timeToComplete,
|
||||||
promptTokens: response.promptTokens,
|
promptTokens: response.promptTokens,
|
||||||
completionTokens: response.completionTokens,
|
completionTokens: response.completionTokens,
|
||||||
@@ -154,7 +151,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
errorMessage: response.message,
|
errorMessage: response.message,
|
||||||
statusCode: response.statusCode,
|
statusCode: response.statusCode,
|
||||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||||
retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
|
retrievalStatus: "ERROR",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
|
|||||||
// We've validated the JSON schema so this should be safe
|
// We've validated the JSON schema so this should be safe
|
||||||
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
|
||||||
|
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
const model = provider.getModel(input);
|
const model = provider.getModel(input);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return {
|
return {
|
||||||
@@ -80,8 +79,6 @@ export default async function parseConstructFn(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
modelProvider: prompt.modelProvider as keyof typeof modelProviders,
|
||||||
// @ts-expect-error TODO FIX ASAP
|
|
||||||
|
|
||||||
model,
|
model,
|
||||||
modelInput: input,
|
modelInput: input,
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user