140 lines
4.3 KiB
TypeScript
140 lines
4.3 KiB
TypeScript
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
|
import {
|
|
type ChatCompletionChunk,
|
|
type ChatCompletion,
|
|
type CompletionCreateParams,
|
|
} from "openai/resources/chat";
|
|
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
|
import { type CompletionResponse } from "../types";
|
|
import { omit } from "lodash-es";
|
|
import { openai } from "~/server/utils/openai";
|
|
import { truthyFilter } from "~/utils/utils";
|
|
import { APIError } from "openai";
|
|
import frontendModelProvider from "./frontend";
|
|
import modelProvider, { type SupportedModel } from ".";
|
|
|
|
/**
 * Folds one streamed chunk into the completion accumulated so far.
 *
 * Choices are matched by `index`. For a choice that already exists, the
 * chunk's `delta.content` and `delta.function_call` fragments are appended to
 * the accumulated message, and a non-nullish `finish_reason` replaces the
 * stored one. A choice seen for the first time is added with its delta
 * (with `role` defaulting to "assistant") as the initial message.
 *
 * `base` is never mutated: choices and messages are copied before anything is
 * appended, so a completion object returned by an earlier call keeps its value.
 *
 * @param base  Completion built from earlier chunks, or `null` for the first chunk.
 * @param chunk The newly received chunk.
 * @returns A new completion object holding the merged choices.
 */
const mergeStreamedChunks = (
  base: ChatCompletion | null,
  chunk: ChatCompletionChunk,
): ChatCompletion => {
  if (base === null) {
    // Seed the accumulator with the chunk's top-level fields and no choices,
    // then merge the chunk through the regular path below.
    return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
  }

  // Shallow-copy every accumulated choice so the writes below land on copies
  // rather than on objects still referenced by `base`.
  const choices: ChatCompletion["choices"] = base.choices.map((c) => ({ ...c }));

  for (const choice of chunk.choices) {
    const baseChoice = choices.find((c) => c.index === choice.index);
    if (baseChoice) {
      baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
      // Copy the message as well; it is still shared with `base` at this point.
      baseChoice.message = { ...(baseChoice.message ?? { role: "assistant" }) };

      if (choice.delta?.content)
        baseChoice.message.content =
          ((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");

      if (choice.delta?.function_call) {
        const previousCall = baseChoice.message.function_call;
        // Always write a fresh object back onto the message. This keeps a
        // function call that starts in a later chunk (previousCall is
        // undefined) and avoids mutating an object owned by `base` or `chunk`.
        baseChoice.message.function_call = {
          name:
            ((previousCall?.name as string) ?? "") +
            ((choice.delta.function_call.name as string) ?? ""),
          arguments:
            ((previousCall?.arguments as string) ?? "") +
            ((choice.delta.function_call.arguments as string) ?? ""),
        };
      }
    } else {
      // First time this index is seen: the delta becomes the initial message.
      choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
    }
  }

  const merged: ChatCompletion = {
    ...base,
    choices,
  };

  return merged;
};
|
|
|
|
/**
 * Requests a chat completion from OpenAI and reports the outcome together
 * with elapsed time, token counts and cost.
 *
 * When `onStream` is provided the request is streamed: every received chunk
 * is merged into a running completion, which is passed to `onStream` after
 * each merge. Token counts are then computed locally with
 * `countOpenAIChatTokens`; if that throws, the error is logged and the counts
 * stay undefined. Without `onStream` a single non-streaming request is made
 * and the counts come from the response's `usage` field (0 when absent).
 *
 * @param input    Parameters forwarded to `openai.chat.completions.create`.
 * @param onStream Callback receiving the partially merged completion, or
 *                 `null` to disable streaming.
 * @returns A "success" response carrying the completion, or an "error"
 *          response. API errors set `autoRetry` only for HTTP 429 and 503;
 *          any other thrown value is logged and reported with `autoRetry: true`.
 */
export async function getCompletion(
  input: CompletionCreateParams,
  onStream: ((partialOutput: ChatCompletion) => void) | null,
): Promise<CompletionResponse<ChatCompletion>> {
  const start = Date.now();
  let finalCompletion: ChatCompletion | null = null;
  let promptTokens: number | undefined = undefined;
  let completionTokens: number | undefined = undefined;
  const modelName = modelProvider.getModel(input) as SupportedModel;

  try {
    if (onStream) {
      const resp = await openai.chat.completions.create(
        { ...input, stream: true },
        {
          maxRetries: 0,
        },
      );
      for await (const part of resp) {
        finalCompletion = mergeStreamedChunks(finalCompletion, part);
        onStream(finalCompletion);
      }
      if (!finalCompletion) {
        // The stream ended without yielding a single chunk.
        return {
          type: "error",
          message: "Streaming failed to return a completion",
          autoRetry: false,
        };
      }
      try {
        // Token counts are computed locally in the streaming path.
        promptTokens = countOpenAIChatTokens(modelName, input.messages);
        completionTokens = countOpenAIChatTokens(
          modelName,
          finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
        );
      } catch (err) {
        // TODO handle this, library seems like maybe it doesn't work with function calls?
        console.error(err);
      }
    } else {
      const resp = await openai.chat.completions.create(
        { ...input, stream: false },
        {
          maxRetries: 0,
        },
      );
      finalCompletion = resp;
      promptTokens = resp.usage?.prompt_tokens ?? 0;
      completionTokens = resp.usage?.completion_tokens ?? 0;
    }
    const timeToComplete = Date.now() - start;

    const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
    let cost: number | undefined = undefined;
    // Explicit nullish checks: a count or price of 0 is a valid value and must
    // still produce a cost; only genuinely missing inputs leave it undefined.
    if (
      promptTokenPrice != null &&
      completionTokenPrice != null &&
      promptTokens !== undefined &&
      completionTokens !== undefined
    ) {
      cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
    }

    return {
      type: "success",
      statusCode: 200,
      value: finalCompletion,
      timeToComplete,
      promptTokens,
      completionTokens,
      cost,
    };
  } catch (error: unknown) {
    if (error instanceof APIError) {
      return {
        type: "error",
        message: error.message,
        autoRetry: error.status === 429 || error.status === 503,
        statusCode: error.status,
      };
    }

    console.error(error);
    return {
      type: "error",
      // Narrow instead of asserting: a non-Error throw has no `.message`.
      message: error instanceof Error ? error.message : String(error),
      autoRetry: true,
    };
  }
}
|