diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
index d9b9d81..c8bddce 100644
--- a/src/components/OutputsTable/OutputCell/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -88,11 +88,9 @@ export default function OutputCell({
   }
 
   const normalizedOutput = modelOutput
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(modelOutput.output as unknown as OutputSchema)
+    ? provider.normalizeOutput(modelOutput.output)
     : streamedMessage
-    ? // @ts-expect-error TODO FIX ASAP
-      provider.normalizeOutput(streamedMessage)
+    ? provider.normalizeOutput(streamedMessage)
     : null;
 
   if (modelOutput && normalizedOutput?.type === "json") {
diff --git a/src/modelProviders/modelProviders.ts b/src/modelProviders/modelProviders.ts
index 0013943..3ec8442 100644
--- a/src/modelProviders/modelProviders.ts
+++ b/src/modelProviders/modelProviders.ts
@@ -1,9 +1,10 @@
 import openaiChatCompletion from "./openai-ChatCompletion";
 import replicateLlama2 from "./replicate-llama2";
+import { type SupportedProvider, type ModelProvider } from "./types";
 
-const modelProviders = {
+const modelProviders: Record<SupportedProvider, ModelProvider<any, any, any>> = {
   "openai/ChatCompletion": openaiChatCompletion,
   "replicate/llama2": replicateLlama2,
-} as const;
+};
 
 export default modelProviders;
diff --git a/src/modelProviders/modelProvidersFrontend.ts b/src/modelProviders/modelProvidersFrontend.ts
index e1ef03c..7068836 100644
--- a/src/modelProviders/modelProvidersFrontend.ts
+++ b/src/modelProviders/modelProvidersFrontend.ts
@@ -1,14 +1,15 @@
 import openaiChatCompletionFrontend from "./openai-ChatCompletion/frontend";
 import replicateLlama2Frontend from "./replicate-llama2/frontend";
+import { type SupportedProvider, type ModelProviderFrontend } from "./types";
 
 // TODO: make sure we get a typescript error if you forget to add a provider here
 
 // Keep attributes here that need to be accessible from the frontend. We can't
 // just include them in the default `modelProviders` object because it has some
 // transient dependencies that can only be imported on the server.
-const modelProvidersFrontend = {
+const modelProvidersFrontend: Record<SupportedProvider, ModelProviderFrontend<any>> = {
   "openai/ChatCompletion": openaiChatCompletionFrontend,
   "replicate/llama2": replicateLlama2Frontend,
-} as const;
+};
 
 export default modelProvidersFrontend;
diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts
index 03cd846..d4eafed 100644
--- a/src/modelProviders/types.ts
+++ b/src/modelProviders/types.ts
@@ -1,6 +1,8 @@
 import { type JSONSchema4 } from "json-schema";
 import { type JsonValue } from "type-fest";
 
+export type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";
+
 type ModelProviderModel = {
   name?: string;
   learnMore?: string;
diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryLLM.task.ts
index 29affe7..9afb050 100644
--- a/src/server/tasks/queryLLM.task.ts
+++ b/src/server/tasks/queryLLM.task.ts
@@ -99,7 +99,6 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
 
   const provider = modelProviders[prompt.modelProvider];
 
-  // @ts-expect-error TODO FIX ASAP
   const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
 
   if (streamingChannel) {
@@ -116,8 +115,6 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
     : null;
 
   for (let i = 0; true; i++) {
-    // @ts-expect-error TODO FIX ASAP
-
     const response = await provider.getCompletion(prompt.modelInput, onStream);
     if (response.type === "success") {
       const inputHash = hashPrompt(prompt);
@@ -126,7 +123,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
       data: {
         scenarioVariantCellId,
         inputHash,
-        output: response.value as unknown as Prisma.InputJsonObject,
+        output: response.value as Prisma.InputJsonObject,
         timeToComplete: response.timeToComplete,
         promptTokens: response.promptTokens,
         completionTokens: response.completionTokens,
@@ -154,7 +151,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => {
         errorMessage: response.message,
         statusCode: response.statusCode,
         retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
-        retrievalStatus: shouldRetry ? "PENDING" : "ERROR",
+        retrievalStatus: "ERROR",
       },
     });
 
diff --git a/src/server/utils/parseConstructFn.ts b/src/server/utils/parseConstructFn.ts
index 1b0d8eb..8bfd667 100644
--- a/src/server/utils/parseConstructFn.ts
+++ b/src/server/utils/parseConstructFn.ts
@@ -70,7 +70,6 @@ export default async function parseConstructFn(
   // We've validated the JSON schema so this should be safe
   const input = prompt.input as Parameters<(typeof provider)["getModel"]>[0];
 
-  // @ts-expect-error TODO FIX ASAP
   const model = provider.getModel(input);
   if (!model) {
     return {
@@ -80,8 +79,6 @@
 
   return {
     modelProvider: prompt.modelProvider as keyof typeof modelProviders,
-    // @ts-expect-error TODO FIX ASAP
-
     model,
     modelInput: input,
   };
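A note on the pattern this patch adopts, for readers outside the codebase: the `@ts-expect-error` suppressions become unnecessary because the provider registries are now annotated as `Record<SupportedProvider, ...>`, which both makes a forgotten provider a compile-time error (the TODO in modelProvidersFrontend.ts) and lets every lookup return a known provider type. A minimal sketch of the idea, using a toy `ToyProvider` type and `registry`/`normalize` names that stand in for the real `ModelProvider` types and are not part of this diff:

// Sketch only — ToyProvider stands in for the real ModelProvider shape.
type SupportedProvider = "openai/ChatCompletion" | "replicate/llama2";

interface ToyProvider {
  normalizeOutput: (output: unknown) => string;
}

// Annotating the registry with Record<SupportedProvider, ToyProvider> means
// omitting either key fails to compile ("Property 'replicate/llama2' is missing...").
const registry: Record<SupportedProvider, ToyProvider> = {
  "openai/ChatCompletion": { normalizeOutput: (o) => JSON.stringify(o) },
  "replicate/llama2": { normalizeOutput: (o) => String(o) },
};

// Every key is now known to map to a ToyProvider, so lookups need no casts
// or @ts-expect-error comments:
function normalize(provider: SupportedProvider, output: unknown): string {
  return registry[provider].normalizeOutput(output);
}

One trade-off worth noting: a `Record` annotation widens each entry to the common provider shape, losing the per-key literal types that the removed `as const` preserved. If that narrowing matters, `const registry = { ... } satisfies Record<SupportedProvider, ToyProvider>` (TypeScript 4.9+) keeps both the exhaustiveness check and the narrow per-key types.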