From 332a2101c0dba540e9c1a9e62d79a2208328f259 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Thu, 20 Jul 2023 18:54:26 -0700 Subject: [PATCH] More work on modelProviders I think everything that's OpenAI-specific is inside modelProviders at this point, so we can get started adding more providers. --- .../OutputsTable/OutputCell/OutputCell.tsx | 42 ++--- src/components/OutputsTable/VariantEditor.tsx | 2 - src/modelProviders/generateTypes.ts | 2 +- .../{index.ts => modelProviders.ts} | 0 src/modelProviders/modelProvidersFrontend.ts | 10 ++ .../openai-ChatCompletion/frontend.ts | 42 +++++ .../openai-ChatCompletion/getCompletion.ts | 142 +++++++++++++++++ .../openai-ChatCompletion/index.ts | 12 +- src/modelProviders/types.ts | 38 ++++- src/server/scripts/openai-test.ts | 19 +++ src/server/tasks/queryLLM.task.ts | 146 +++++++----------- src/server/types.ts | 11 -- src/server/utils/fillTemplate.ts | 23 --- src/server/utils/generateNewCell.ts | 6 + src/server/utils/getCompletion.ts | 107 ------------- src/server/utils/openai.ts | 59 ------- src/server/utils/parseConstructFn.ts | 2 +- src/state/sync.tsx | 1 - src/utils/countTokens.ts | 2 +- src/utils/useSocket.ts | 7 +- src/utils/utils.ts | 1 + 21 files changed, 344 insertions(+), 330 deletions(-) rename src/modelProviders/{index.ts => modelProviders.ts} (100%) create mode 100644 src/modelProviders/modelProvidersFrontend.ts create mode 100644 src/modelProviders/openai-ChatCompletion/frontend.ts create mode 100644 src/modelProviders/openai-ChatCompletion/getCompletion.ts create mode 100644 src/server/scripts/openai-test.ts delete mode 100644 src/server/utils/getCompletion.ts create mode 100644 src/utils/utils.ts diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index 595dd20..c9b2c00 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -6,11 +6,11 @@ import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; import { type ReactElement, useState, useEffect } from "react"; -import { type ChatCompletion } from "openai/resources/chat"; import useSocket from "~/utils/useSocket"; import { OutputStats } from "./OutputStats"; import { ErrorHandler } from "./ErrorHandler"; import { CellOptions } from "./CellOptions"; +import modelProvidersFrontend from "~/modelProviders/modelProvidersFrontend"; export default function OutputCell({ scenario, @@ -33,15 +33,17 @@ export default function OutputCell({ if (!templateHasVariables) disabledReason = "Add a value to the scenario variables to see output"; - // if (variant.config === null || Object.keys(variant.config).length === 0) - // disabledReason = "Save your prompt variant to see output"; - const [refetchInterval, setRefetchInterval] = useState(0); const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery( { scenarioId: scenario.id, variantId: variant.id }, { refetchInterval }, ); + const provider = + modelProvidersFrontend[variant.modelProvider as keyof typeof modelProvidersFrontend]; + + type OutputSchema = Parameters[0]; + const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation(); const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => { await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id }); @@ -66,8 +68,7 @@ export default function 
OutputCell({ const modelOutput = cell?.modelOutput; // Disconnect from socket if we're not streaming anymore - const streamedMessage = useSocket(cell?.streamingChannel); - const streamedContent = streamedMessage?.choices?.[0]?.message?.content; + const streamedMessage = useSocket(cell?.streamingChannel); if (!vars) return null; @@ -86,19 +87,13 @@ export default function OutputCell({ return ; } - const response = modelOutput?.output as unknown as ChatCompletion; - const message = response?.choices?.[0]?.message; - - if (modelOutput && message?.function_call) { - const rawArgs = message.function_call.arguments ?? "null"; - let parsedArgs: string; - try { - parsedArgs = JSON.parse(rawArgs); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`; - } + const normalizedOutput = modelOutput + ? provider.normalizeOutput(modelOutput.output as unknown as OutputSchema) + : streamedMessage + ? provider.normalizeOutput(streamedMessage) + : null; + if (modelOutput && normalizedOutput?.type === "json") { return ( - {stringify( - { - function: message.function_call.name, - args: parsedArgs, - }, - { maxLength: 40 }, - )} + {stringify(normalizedOutput.value, { maxLength: 40 })} @@ -133,8 +122,7 @@ export default function OutputCell({ ); } - const contentToDisplay = - message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output); + const contentToDisplay = (normalizedOutput?.type === "text" && normalizedOutput.value) || ""; return ( diff --git a/src/components/OutputsTable/VariantEditor.tsx b/src/components/OutputsTable/VariantEditor.tsx index 18ab5cb..67c40c8 100644 --- a/src/components/OutputsTable/VariantEditor.tsx +++ b/src/components/OutputsTable/VariantEditor.tsx @@ -50,8 +50,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) { // Make sure the user defined the prompt with the string "prompt\w*=" somewhere const promptRegex = /definePrompt\(/; if (!promptRegex.test(currentFn)) { - console.log("no prompt"); - console.log(currentFn); toast({ title: "Missing prompt", description: "Please define the prompt (eg. `definePrompt(...`", diff --git a/src/modelProviders/generateTypes.ts b/src/modelProviders/generateTypes.ts index e82ae9d..3f403a6 100644 --- a/src/modelProviders/generateTypes.ts +++ b/src/modelProviders/generateTypes.ts @@ -1,5 +1,5 @@ import { type JSONSchema4Object } from "json-schema"; -import modelProviders from "."; +import modelProviders from "./modelProviders"; import { compile } from "json-schema-to-typescript"; import dedent from "dedent"; diff --git a/src/modelProviders/index.ts b/src/modelProviders/modelProviders.ts similarity index 100% rename from src/modelProviders/index.ts rename to src/modelProviders/modelProviders.ts diff --git a/src/modelProviders/modelProvidersFrontend.ts b/src/modelProviders/modelProvidersFrontend.ts new file mode 100644 index 0000000..42d6d7d --- /dev/null +++ b/src/modelProviders/modelProvidersFrontend.ts @@ -0,0 +1,10 @@ +import modelProviderFrontend from "./openai-ChatCompletion/frontend"; + +// Keep attributes here that need to be accessible from the frontend. We can't +// just include them in the default `modelProviders` object because it has some +// transient dependencies that can only be imported on the server. 
+const modelProvidersFrontend = { + "openai/ChatCompletion": modelProviderFrontend, +} as const; + +export default modelProvidersFrontend; diff --git a/src/modelProviders/openai-ChatCompletion/frontend.ts b/src/modelProviders/openai-ChatCompletion/frontend.ts new file mode 100644 index 0000000..ba90c1b --- /dev/null +++ b/src/modelProviders/openai-ChatCompletion/frontend.ts @@ -0,0 +1,42 @@ +import { type JsonValue } from "type-fest"; +import { type OpenaiChatModelProvider } from "."; +import { type ModelProviderFrontend } from "../types"; + +const modelProviderFrontend: ModelProviderFrontend = { + normalizeOutput: (output) => { + const message = output.choices[0]?.message; + if (!message) + return { + type: "json", + value: output as unknown as JsonValue, + }; + + if (message.content) { + return { + type: "text", + value: message.content, + }; + } else if (message.function_call) { + let args = message.function_call.arguments ?? ""; + try { + args = JSON.parse(args); + } catch (e) { + // Ignore + } + return { + type: "json", + value: { + ...message.function_call, + arguments: args, + }, + }; + } else { + return { + type: "json", + value: message as unknown as JsonValue, + }; + } + }, +}; + +export default modelProviderFrontend; diff --git a/src/modelProviders/openai-ChatCompletion/getCompletion.ts b/src/modelProviders/openai-ChatCompletion/getCompletion.ts new file mode 100644 index 0000000..352cadf --- /dev/null +++ b/src/modelProviders/openai-ChatCompletion/getCompletion.ts @@ -0,0 +1,142 @@ +/* eslint-disable @typescript-eslint/no-unsafe-call */ +import { + type ChatCompletionChunk, + type ChatCompletion, + type CompletionCreateParams, +} from "openai/resources/chat"; +import { countOpenAIChatTokens } from "~/utils/countTokens"; +import { type CompletionResponse } from "../types"; +import { omit } from "lodash-es"; +import { openai } from "~/server/utils/openai"; +import { type OpenAIChatModel } from "~/server/types"; +import { truthyFilter } from "~/utils/utils"; +import { APIError } from "openai"; +import { modelStats } from "../modelStats"; + +const mergeStreamedChunks = ( + base: ChatCompletion | null, + chunk: ChatCompletionChunk, +): ChatCompletion => { + if (base === null) { + return mergeStreamedChunks({ ...chunk, choices: [] }, chunk); + } + + const choices = [...base.choices]; + for (const choice of chunk.choices) { + const baseChoice = choices.find((c) => c.index === choice.index); + if (baseChoice) { + baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason; + baseChoice.message = baseChoice.message ?? { role: "assistant" }; + + if (choice.delta?.content) + baseChoice.message.content = + ((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? ""); + if (choice.delta?.function_call) { + const fnCall = baseChoice.message.function_call ?? {}; + fnCall.name = + ((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? ""); + fnCall.arguments = + ((fnCall.arguments as string) ?? "") + + ((choice.delta.function_call.arguments as string) ?? 
""); + } + } else { + choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } }); + } + } + + const merged: ChatCompletion = { + ...base, + choices, + }; + + return merged; +}; + +export async function getCompletion( + input: CompletionCreateParams, + onStream: ((partialOutput: ChatCompletion) => void) | null, +): Promise> { + const start = Date.now(); + let finalCompletion: ChatCompletion | null = null; + let promptTokens: number | undefined = undefined; + let completionTokens: number | undefined = undefined; + + try { + if (onStream) { + const resp = await openai.chat.completions.create( + { ...input, stream: true }, + { + maxRetries: 0, + }, + ); + for await (const part of resp) { + finalCompletion = mergeStreamedChunks(finalCompletion, part); + onStream(finalCompletion); + } + if (!finalCompletion) { + return { + type: "error", + message: "Streaming failed to return a completion", + autoRetry: false, + }; + } + try { + promptTokens = countOpenAIChatTokens( + input.model as keyof typeof OpenAIChatModel, + input.messages, + ); + completionTokens = countOpenAIChatTokens( + input.model as keyof typeof OpenAIChatModel, + finalCompletion.choices.map((c) => c.message).filter(truthyFilter), + ); + } catch (err) { + // TODO handle this, library seems like maybe it doesn't work with function calls? + console.error(err); + } + } else { + const resp = await openai.chat.completions.create( + { ...input, stream: false }, + { + maxRetries: 0, + }, + ); + finalCompletion = resp; + promptTokens = resp.usage?.prompt_tokens ?? 0; + completionTokens = resp.usage?.completion_tokens ?? 0; + } + const timeToComplete = Date.now() - start; + + const stats = modelStats[input.model as keyof typeof OpenAIChatModel]; + let cost = undefined; + if (stats && promptTokens && completionTokens) { + cost = promptTokens * stats.promptTokenPrice + completionTokens * stats.completionTokenPrice; + } + + return { + type: "success", + statusCode: 200, + value: finalCompletion, + timeToComplete, + promptTokens, + completionTokens, + cost, + }; + } catch (error: unknown) { + console.error("ERROR IS", error); + if (error instanceof APIError) { + return { + type: "error", + message: error.message, + autoRetry: error.status === 429 || error.status === 503, + statusCode: error.status, + }; + } else { + console.error(error); + return { + type: "error", + message: (error as Error).message, + autoRetry: true, + }; + } + } +} diff --git a/src/modelProviders/openai-ChatCompletion/index.ts b/src/modelProviders/openai-ChatCompletion/index.ts index 5c2c4cb..294877b 100644 --- a/src/modelProviders/openai-ChatCompletion/index.ts +++ b/src/modelProviders/openai-ChatCompletion/index.ts @@ -1,7 +1,8 @@ import { type JSONSchema4 } from "json-schema"; import { type ModelProvider } from "../types"; import inputSchema from "./codegen/input.schema.json"; -import { type CompletionCreateParams } from "openai/resources/chat"; +import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; +import { getCompletion } from "./getCompletion"; const supportedModels = [ "gpt-4-0613", @@ -12,7 +13,13 @@ const supportedModels = [ type SupportedModel = (typeof supportedModels)[number]; -const modelProvider: ModelProvider = { +export type OpenaiChatModelProvider = ModelProvider< + SupportedModel, + CompletionCreateParams, + ChatCompletion +>; + +const modelProvider: OpenaiChatModelProvider = { name: "OpenAI ChatCompletion", models: { "gpt-4-0613": { @@ -49,6 +56,7 @@ const modelProvider: ModelProvider = { 
}, inputSchema: inputSchema as JSONSchema4, shouldStream: (input) => input.stream ?? false, + getCompletion, }; export default modelProvider; diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts index 3bccbb7..edcbe85 100644 --- a/src/modelProviders/types.ts +++ b/src/modelProviders/types.ts @@ -1,14 +1,48 @@ import { type JSONSchema4 } from "json-schema"; +import { type JsonValue } from "type-fest"; -export type ModelProviderModel = { +type ModelProviderModel = { name: string; learnMore: string; }; -export type ModelProvider = { +export type CompletionResponse = + | { type: "error"; message: string; autoRetry: boolean; statusCode?: number } + | { + type: "success"; + value: T; + timeToComplete: number; + statusCode: number; + promptTokens?: number; + completionTokens?: number; + cost?: number; + }; + +export type ModelProvider = { name: string; models: Record; getModel: (input: InputSchema) => SupportedModels | null; shouldStream: (input: InputSchema) => boolean; inputSchema: JSONSchema4; + getCompletion: ( + input: InputSchema, + onStream: ((partialOutput: OutputSchema) => void) | null, + ) => Promise>; + + // This is just a convenience for type inference, don't use it at runtime + _outputSchema?: OutputSchema | null; +}; + +export type NormalizedOutput = + | { + type: "text"; + value: string; + } + | { + type: "json"; + value: JsonValue; + }; + +export type ModelProviderFrontend> = { + normalizeOutput: (output: NonNullable) => NormalizedOutput; }; diff --git a/src/server/scripts/openai-test.ts b/src/server/scripts/openai-test.ts new file mode 100644 index 0000000..77a76fe --- /dev/null +++ b/src/server/scripts/openai-test.ts @@ -0,0 +1,19 @@ +import "dotenv/config"; +import { openai } from "../utils/openai"; + +const resp = await openai.chat.completions.create({ + model: "gpt-3.5-turbo-0613", + stream: true, + messages: [ + { + role: "user", + content: "count to 20", + }, + ], +}); + +for await (const part of resp) { + console.log("part", part); +} + +console.log("final resp", resp); diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryLLM.task.ts index 6263cee..b2b86cb 100644 --- a/src/server/tasks/queryLLM.task.ts +++ b/src/server/tasks/queryLLM.task.ts @@ -1,15 +1,18 @@ import { prisma } from "~/server/db"; import defineTask from "./defineTask"; -import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion"; import { sleep } from "../utils/sleep"; import { generateChannel } from "~/utils/generateChannel"; import { runEvalsForOutput } from "../utils/evaluations"; -import { type CompletionCreateParams } from "openai/resources/chat"; import { type Prisma } from "@prisma/client"; import parseConstructFn from "../utils/parseConstructFn"; import hashPrompt from "../utils/hashPrompt"; import { type JsonObject } from "type-fest"; -import modelProviders from "~/modelProviders"; +import modelProviders from "~/modelProviders/modelProviders"; +import { wsConnection } from "~/utils/wsConnection"; + +export type queryLLMJob = { + scenarioVariantCellId: string; +}; const MAX_AUTO_RETRIES = 10; const MIN_DELAY = 500; // milliseconds @@ -21,51 +24,6 @@ function calculateDelay(numPreviousTries: number): number { return baseDelay + jitter; } -const getCompletionWithRetries = async ( - cellId: string, - payload: JsonObject, - channel?: string, -): Promise => { - let modelResponse: CompletionResponse | null = null; - try { - for (let i = 0; i < MAX_AUTO_RETRIES; i++) { - modelResponse = await getOpenAIChatCompletion( - payload as unknown as 
CompletionCreateParams, - channel, - ); - if ( - (modelResponse.statusCode !== 429 && modelResponse.statusCode !== 503) || - i === MAX_AUTO_RETRIES - 1 - ) { - return modelResponse; - } - const delay = calculateDelay(i); - await prisma.scenarioVariantCell.update({ - where: { id: cellId }, - data: { - errorMessage: "Rate limit exceeded", - statusCode: 429, - retryTime: new Date(Date.now() + delay), - }, - }); - // TODO: Maybe requeue the job so other jobs can run in the future? - await sleep(delay); - } - throw new Error("Max retries limit reached"); - } catch (error: unknown) { - return { - statusCode: modelResponse?.statusCode ?? 500, - errorMessage: modelResponse?.errorMessage ?? (error as Error).message, - output: null, - timeToComplete: 0, - }; - } -}; - -export type queryLLMJob = { - scenarioVariantCellId: string; -}; - export const queryLLM = defineTask("queryLLM", async (task) => { const { scenarioVariantCellId } = task; const cell = await prisma.scenarioVariantCell.findUnique({ @@ -141,57 +99,67 @@ export const queryLLM = defineTask("queryLLM", async (task) => { const provider = modelProviders[prompt.modelProvider]; - const streamingEnabled = provider.shouldStream(prompt.modelInput); - let streamingChannel; + const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null; - if (streamingEnabled) { - streamingChannel = generateChannel(); + if (streamingChannel) { // Save streaming channel so that UI can connect to it await prisma.scenarioVariantCell.update({ where: { id: scenarioVariantCellId }, data: { streamingChannel }, }); } + const onStream = streamingChannel + ? (partialOutput: (typeof provider)["_outputSchema"]) => { + wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput }); + } + : null; - const modelResponse = await getCompletionWithRetries( - scenarioVariantCellId, - prompt.modelInput as unknown as JsonObject, - streamingChannel, - ); + for (let i = 0; true; i++) { + const response = await provider.getCompletion(prompt.modelInput, onStream); + if (response.type === "success") { + const inputHash = hashPrompt(prompt); - let modelOutput = null; - if (modelResponse.statusCode === 200) { - const inputHash = hashPrompt(prompt); - - modelOutput = await prisma.modelOutput.create({ - data: { - scenarioVariantCellId, - inputHash, - output: modelResponse.output as unknown as Prisma.InputJsonObject, - timeToComplete: modelResponse.timeToComplete, - promptTokens: modelResponse.promptTokens, - completionTokens: modelResponse.completionTokens, - cost: modelResponse.cost, - }, - }); - } - - await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, - data: { - statusCode: modelResponse.statusCode, - errorMessage: modelResponse.errorMessage, - streamingChannel: null, - retrievalStatus: modelOutput ? 
"COMPLETE" : "ERROR", - modelOutput: { - connect: { - id: modelOutput?.id, + const modelOutput = await prisma.modelOutput.create({ + data: { + scenarioVariantCellId, + inputHash, + output: response.value as unknown as Prisma.InputJsonObject, + timeToComplete: response.timeToComplete, + promptTokens: response.promptTokens, + completionTokens: response.completionTokens, + cost: response.cost, }, - }, - }, - }); + }); - if (modelOutput) { - await runEvalsForOutput(variant.experimentId, scenario, modelOutput); + await prisma.scenarioVariantCell.update({ + where: { id: scenarioVariantCellId }, + data: { + statusCode: response.statusCode, + retrievalStatus: "COMPLETE", + }, + }); + + await runEvalsForOutput(variant.experimentId, scenario, modelOutput); + break; + } else { + const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; + const delay = calculateDelay(i); + + await prisma.scenarioVariantCell.update({ + where: { id: scenarioVariantCellId }, + data: { + errorMessage: response.message, + statusCode: response.statusCode, + retryTime: shouldRetry ? new Date(Date.now() + delay) : null, + retrievalStatus: shouldRetry ? "PENDING" : "ERROR", + }, + }); + + if (shouldRetry) { + await sleep(delay); + } else { + break; + } + } } }); diff --git a/src/server/types.ts b/src/server/types.ts index c82930b..983bacd 100644 --- a/src/server/types.ts +++ b/src/server/types.ts @@ -1,14 +1,3 @@ -export type JSONSerializable = - | string - | number - | boolean - | null - | JSONSerializable[] - | { [key: string]: JSONSerializable }; - -// Placeholder for now -export type OpenAIChatConfig = NonNullable; - export enum OpenAIChatModel { "gpt-4" = "gpt-4", "gpt-4-0613" = "gpt-4-0613", diff --git a/src/server/utils/fillTemplate.ts b/src/server/utils/fillTemplate.ts index 224c66c..b6108e2 100644 --- a/src/server/utils/fillTemplate.ts +++ b/src/server/utils/fillTemplate.ts @@ -1,5 +1,3 @@ -import { type JSONSerializable } from "../types"; - export type VariableMap = Record; // Escape quotes to match the way we encode JSON @@ -15,24 +13,3 @@ export function escapeRegExp(str: string) { export function fillTemplate(template: string, variables: VariableMap): string { return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || ""); } - -export function fillTemplateJson( - template: T, - variables: VariableMap, -): T { - if (typeof template === "string") { - return fillTemplate(template, variables) as T; - } else if (Array.isArray(template)) { - return template.map((item) => fillTemplateJson(item, variables)) as T; - } else if (typeof template === "object" && template !== null) { - return Object.keys(template).reduce( - (acc, key) => { - acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables); - return acc; - }, - {} as { [key: string]: JSONSerializable } & T, - ); - } else { - return template; - } -} diff --git a/src/server/utils/generateNewCell.ts b/src/server/utils/generateNewCell.ts index b47ad57..cdbe50b 100644 --- a/src/server/utils/generateNewCell.ts +++ b/src/server/utils/generateNewCell.ts @@ -46,6 +46,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) => testScenarioId: scenarioId, statusCode: 400, errorMessage: parsedConstructFn.error, + retrievalStatus: "ERROR", }, }); } @@ -57,6 +58,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) => promptVariantId: variantId, testScenarioId: scenarioId, prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue, + retrievalStatus: "PENDING", }, include: 
{ modelOutput: true, @@ -83,6 +85,10 @@ export const generateNewCell = async (variantId: string, scenarioId: string) => updatedAt: matchingModelOutput.updatedAt, }, }); + await prisma.scenarioVariantCell.update({ + where: { id: cell.id }, + data: { retrievalStatus: "COMPLETE" }, + }); } else { cell = await queueLLMRetrievalTask(cell.id); } diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts deleted file mode 100644 index b836f3a..0000000 --- a/src/server/utils/getCompletion.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* eslint-disable @typescript-eslint/no-unsafe-call */ -import { isObject } from "lodash-es"; -import { streamChatCompletion } from "./openai"; -import { wsConnection } from "~/utils/wsConnection"; -import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; -import { type SupportedModel, type OpenAIChatModel } from "../types"; -import { env } from "~/env.mjs"; -import { countOpenAIChatTokens } from "~/utils/countTokens"; -import { rateLimitErrorMessage } from "~/sharedStrings"; -import { modelStats } from "../../modelProviders/modelStats"; - -export type CompletionResponse = { - output: ChatCompletion | null; - statusCode: number; - errorMessage: string | null; - timeToComplete: number; - promptTokens?: number; - completionTokens?: number; - cost?: number; -}; - -export async function getOpenAIChatCompletion( - payload: CompletionCreateParams, - channel?: string, -): Promise { - // If functions are enabled, disable streaming so that we get the full response with token counts - if (payload.functions?.length) payload.stream = false; - const start = Date.now(); - const response = await fetch("https://api.openai.com/v1/chat/completions", { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${env.OPENAI_API_KEY}`, - }, - body: JSON.stringify(payload), - }); - - const resp: CompletionResponse = { - output: null, - errorMessage: null, - statusCode: response.status, - timeToComplete: 0, - }; - - try { - if (payload.stream) { - const completion = streamChatCompletion(payload as unknown as CompletionCreateParams); - let finalOutput: ChatCompletion | null = null; - await (async () => { - for await (const partialCompletion of completion) { - finalOutput = partialCompletion; - wsConnection.emit("message", { channel, payload: partialCompletion }); - } - })().catch((err) => console.error(err)); - if (finalOutput) { - resp.output = finalOutput; - resp.timeToComplete = Date.now() - start; - } - } else { - resp.timeToComplete = Date.now() - start; - resp.output = await response.json(); - } - - if (!response.ok) { - if (response.status === 429) { - resp.errorMessage = rateLimitErrorMessage; - } else if ( - isObject(resp.output) && - "error" in resp.output && - isObject(resp.output.error) && - "message" in resp.output.error - ) { - // If it's an object, try to get the error message - resp.errorMessage = resp.output.error.message?.toString() ?? 
"Unknown error"; - } - } - - if (isObject(resp.output) && "usage" in resp.output) { - const usage = resp.output.usage as unknown as ChatCompletion.Usage; - resp.promptTokens = usage.prompt_tokens; - resp.completionTokens = usage.completion_tokens; - } else if (isObject(resp.output) && "choices" in resp.output) { - const model = payload.model as unknown as OpenAIChatModel; - resp.promptTokens = countOpenAIChatTokens(model, payload.messages); - const choices = resp.output.choices as unknown as ChatCompletion.Choice[]; - const message = choices[0]?.message; - if (message) { - const messages = [message]; - resp.completionTokens = countOpenAIChatTokens(model, messages); - } - } - - const stats = modelStats[resp.output?.model as SupportedModel]; - if (stats && resp.promptTokens && resp.completionTokens) { - resp.cost = - resp.promptTokens * stats.promptTokenPrice + - resp.completionTokens * stats.completionTokenPrice; - } - } catch (e) { - console.error(e); - if (response.ok) { - resp.errorMessage = "Failed to parse response"; - } - } - - return resp; -} diff --git a/src/server/utils/openai.ts b/src/server/utils/openai.ts index bd210ea..59fcb06 100644 --- a/src/server/utils/openai.ts +++ b/src/server/utils/openai.ts @@ -1,64 +1,5 @@ -import { omit } from "lodash-es"; import { env } from "~/env.mjs"; import OpenAI from "openai"; -import { - type ChatCompletion, - type ChatCompletionChunk, - type CompletionCreateParams, -} from "openai/resources/chat"; export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY }); - -export const mergeStreamedChunks = ( - base: ChatCompletion | null, - chunk: ChatCompletionChunk, -): ChatCompletion => { - if (base === null) { - return mergeStreamedChunks({ ...chunk, choices: [] }, chunk); - } - - const choices = [...base.choices]; - for (const choice of chunk.choices) { - const baseChoice = choices.find((c) => c.index === choice.index); - if (baseChoice) { - baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason; - baseChoice.message = baseChoice.message ?? { role: "assistant" }; - - if (choice.delta?.content) - baseChoice.message.content = - ((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? ""); - if (choice.delta?.function_call) { - const fnCall = baseChoice.message.function_call ?? {}; - fnCall.name = - ((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? ""); - fnCall.arguments = - ((fnCall.arguments as string) ?? "") + - ((choice.delta.function_call.arguments as string) ?? 
""); - } - } else { - choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } }); - } - } - - const merged: ChatCompletion = { - ...base, - choices, - }; - - return merged; -}; - -export const streamChatCompletion = async function* (body: CompletionCreateParams) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-call - const resp = await openai.chat.completions.create({ - ...body, - stream: true, - }); - - let mergedChunks: ChatCompletion | null = null; - for await (const part of resp) { - mergedChunks = mergeStreamedChunks(mergedChunks, part); - yield mergedChunks; - } -}; diff --git a/src/server/utils/parseConstructFn.ts b/src/server/utils/parseConstructFn.ts index 0176b8a..8bfd667 100644 --- a/src/server/utils/parseConstructFn.ts +++ b/src/server/utils/parseConstructFn.ts @@ -1,4 +1,4 @@ -import modelProviders from "~/modelProviders"; +import modelProviders from "~/modelProviders/modelProviders"; import ivm from "isolated-vm"; import { isObject, isString } from "lodash-es"; import { type JsonObject } from "type-fest"; diff --git a/src/state/sync.tsx b/src/state/sync.tsx index b3136e5..7d05411 100644 --- a/src/state/sync.tsx +++ b/src/state/sync.tsx @@ -22,7 +22,6 @@ export function SyncAppStore() { const setApi = useAppStore((state) => state.setApi); useEffect(() => { - console.log("setting api", utils); setApi(utils); }, [utils, setApi]); diff --git a/src/utils/countTokens.ts b/src/utils/countTokens.ts index db2ef0a..c21d432 100644 --- a/src/utils/countTokens.ts +++ b/src/utils/countTokens.ts @@ -9,7 +9,7 @@ interface GPTTokensMessageItem { } export const countOpenAIChatTokens = ( - model: OpenAIChatModel, + model: keyof typeof OpenAIChatModel, messages: ChatCompletion.Choice.Message[], ) => { return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] }) diff --git a/src/utils/useSocket.ts b/src/utils/useSocket.ts index 175036c..ba69387 100644 --- a/src/utils/useSocket.ts +++ b/src/utils/useSocket.ts @@ -1,13 +1,12 @@ -import { type ChatCompletion } from "openai/resources/chat"; import { useRef, useState, useEffect } from "react"; import { io, type Socket } from "socket.io-client"; import { env } from "~/env.mjs"; const url = env.NEXT_PUBLIC_SOCKET_URL; -export default function useSocket(channel?: string | null) { +export default function useSocket(channel?: string | null) { const socketRef = useRef(); - const [message, setMessage] = useState(null); + const [message, setMessage] = useState(null); useEffect(() => { if (!channel) return; @@ -21,7 +20,7 @@ export default function useSocket(channel?: string | null) { socketRef.current?.emit("join", channel); // Listen for 'message' events - socketRef.current?.on("message", (message: ChatCompletion) => { + socketRef.current?.on("message", (message: T) => { setMessage(message); }); }); diff --git a/src/utils/utils.ts b/src/utils/utils.ts new file mode 100644 index 0000000..217c0fb --- /dev/null +++ b/src/utils/utils.ts @@ -0,0 +1 @@ +export const truthyFilter = (x: T | null | undefined): x is T => Boolean(x);