More work on modelProviders
I think everything that's OpenAI-specific is inside modelProviders at this point, so we can get started adding more providers.
@@ -1,5 +1,3 @@
import { type JSONSerializable } from "../types";

export type VariableMap = Record<string, string>;

// Escape quotes to match the way we encode JSON
@@ -15,24 +13,3 @@ export function escapeRegExp(str: string) {
export function fillTemplate(template: string, variables: VariableMap): string {
  return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
}

export function fillTemplateJson<T extends JSONSerializable>(
  template: T,
  variables: VariableMap,
): T {
  if (typeof template === "string") {
    return fillTemplate(template, variables) as T;
  } else if (Array.isArray(template)) {
    return template.map((item) => fillTemplateJson(item, variables)) as T;
  } else if (typeof template === "object" && template !== null) {
    return Object.keys(template).reduce(
      (acc, key) => {
        acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
        return acc;
      },
      {} as { [key: string]: JSONSerializable } & T,
    );
  } else {
    return template;
  }
}
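
Both fillTemplate and fillTemplateJson are removed from this file by the hunk above; per the commit message they presumably now live under modelProviders. For reference, a minimal usage sketch of the recursive filler (the template object and variable values are invented for illustration):

const template = {
  model: "gpt-3.5-turbo",
  messages: [{ role: "user", content: "Summarize: {{text}}" }],
};
const variables: VariableMap = { text: "some user-provided input" };

// fillTemplateJson recurses through arrays and objects and applies
// fillTemplate to every string leaf, so the {{text}} placeholder inside
// the nested message content gets replaced. Unknown {{placeholders}} are
// replaced with an empty string, per `variables[key] || ""` above.
const filled = fillTemplateJson(template, variables);
// filled.messages[0].content === "Summarize: some user-provided input"
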
@@ -46,6 +46,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
        testScenarioId: scenarioId,
        statusCode: 400,
        errorMessage: parsedConstructFn.error,
        retrievalStatus: "ERROR",
      },
    });
  }
@@ -57,6 +58,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
        promptVariantId: variantId,
        testScenarioId: scenarioId,
        prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
        retrievalStatus: "PENDING",
      },
      include: {
        modelOutput: true,
@@ -83,6 +85,10 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
          updatedAt: matchingModelOutput.updatedAt,
        },
      });
      await prisma.scenarioVariantCell.update({
        where: { id: cell.id },
        data: { retrievalStatus: "COMPLETE" },
      });
    } else {
      cell = await queueLLMRetrievalTask(cell.id);
    }
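
Taken together, the three hunks above thread an explicit retrievalStatus through the cell lifecycle. A rough sketch of the states as they appear in this diff (the union type itself is an assumption; only the string literals are visible in the hunks):

// Visible in the hunks: a cell is created as ERROR when parsing the construct
// fn fails, or as PENDING otherwise; a matching cached model output flips it
// to COMPLETE, and anything else is handed off to queueLLMRetrievalTask.
type RetrievalStatus = "PENDING" | "COMPLETE" | "ERROR";
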
@@ -1,107 +0,0 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash-es";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../../modelProviders/modelStats";

export type CompletionResponse = {
  output: ChatCompletion | null;
  statusCode: number;
  errorMessage: string | null;
  timeToComplete: number;
  promptTokens?: number;
  completionTokens?: number;
  cost?: number;
};

export async function getOpenAIChatCompletion(
  payload: CompletionCreateParams,
  channel?: string,
): Promise<CompletionResponse> {
  // If functions are enabled, disable streaming so that we get the full response with token counts
  if (payload.functions?.length) payload.stream = false;
  const start = Date.now();
  const response = await fetch("https://api.openai.com/v1/chat/completions", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify(payload),
  });

  const resp: CompletionResponse = {
    output: null,
    errorMessage: null,
    statusCode: response.status,
    timeToComplete: 0,
  };

  try {
    if (payload.stream) {
      const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
      let finalOutput: ChatCompletion | null = null;
      await (async () => {
        for await (const partialCompletion of completion) {
          finalOutput = partialCompletion;
          wsConnection.emit("message", { channel, payload: partialCompletion });
        }
      })().catch((err) => console.error(err));
      if (finalOutput) {
        resp.output = finalOutput;
        resp.timeToComplete = Date.now() - start;
      }
    } else {
      resp.timeToComplete = Date.now() - start;
      resp.output = await response.json();
    }

    if (!response.ok) {
      if (response.status === 429) {
        resp.errorMessage = rateLimitErrorMessage;
      } else if (
        isObject(resp.output) &&
        "error" in resp.output &&
        isObject(resp.output.error) &&
        "message" in resp.output.error
      ) {
        // If it's an object, try to get the error message
        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
      }
    }

    if (isObject(resp.output) && "usage" in resp.output) {
      const usage = resp.output.usage as unknown as ChatCompletion.Usage;
      resp.promptTokens = usage.prompt_tokens;
      resp.completionTokens = usage.completion_tokens;
    } else if (isObject(resp.output) && "choices" in resp.output) {
      const model = payload.model as unknown as OpenAIChatModel;
      resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
      const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
      const message = choices[0]?.message;
      if (message) {
        const messages = [message];
        resp.completionTokens = countOpenAIChatTokens(model, messages);
      }
    }

    const stats = modelStats[resp.output?.model as SupportedModel];
    if (stats && resp.promptTokens && resp.completionTokens) {
      resp.cost =
        resp.promptTokens * stats.promptTokenPrice +
        resp.completionTokens * stats.completionTokenPrice;
    }
  } catch (e) {
    console.error(e);
    if (response.ok) {
      resp.errorMessage = "Failed to parse response";
    }
  }

  return resp;
}
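
Before deletion, this file priced a completed call by multiplying token counts by the per-token prices looked up in modelStats. A self-contained sketch of that arithmetic (the price figures below are invented for illustration, not actual OpenAI pricing):

// Hypothetical per-token prices, mirroring the promptTokenPrice and
// completionTokenPrice fields the deleted code read from modelStats.
const stats = { promptTokenPrice: 0.0000015, completionTokenPrice: 0.000002 };

const promptTokens = 1200;
const completionTokens = 300;

// Same formula as the deleted cost block: price each side, then sum.
const cost =
  promptTokens * stats.promptTokenPrice +
  completionTokens * stats.completionTokenPrice;
// 1200 * 0.0000015 + 300 * 0.000002 = 0.0018 + 0.0006 = 0.0024 (USD)
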
@@ -1,64 +1,5 @@
import { omit } from "lodash-es";
import { env } from "~/env.mjs";

import OpenAI from "openai";
import {
  type ChatCompletion,
  type ChatCompletionChunk,
  type CompletionCreateParams,
} from "openai/resources/chat";

export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY });

export const mergeStreamedChunks = (
  base: ChatCompletion | null,
  chunk: ChatCompletionChunk,
): ChatCompletion => {
  if (base === null) {
    return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
  }

  const choices = [...base.choices];
  for (const choice of chunk.choices) {
    const baseChoice = choices.find((c) => c.index === choice.index);
    if (baseChoice) {
      baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
      baseChoice.message = baseChoice.message ?? { role: "assistant" };

      if (choice.delta?.content)
        baseChoice.message.content =
          ((baseChoice.message.content as string) ?? "") + (choice.delta.content ?? "");
      if (choice.delta?.function_call) {
        const fnCall = baseChoice.message.function_call ?? {};
        fnCall.name =
          ((fnCall.name as string) ?? "") + ((choice.delta.function_call.name as string) ?? "");
        fnCall.arguments =
          ((fnCall.arguments as string) ?? "") +
          ((choice.delta.function_call.arguments as string) ?? "");
      }
    } else {
      choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
    }
  }

  const merged: ChatCompletion = {
    ...base,
    choices,
  };

  return merged;
};

export const streamChatCompletion = async function* (body: CompletionCreateParams) {
  // eslint-disable-next-line @typescript-eslint/no-unsafe-call
  const resp = await openai.chat.completions.create({
    ...body,
    stream: true,
  });

  let mergedChunks: ChatCompletion | null = null;
  for await (const part of resp) {
    mergedChunks = mergeStreamedChunks(mergedChunks, part);
    yield mergedChunks;
  }
};
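
Because streamChatCompletion yields the merged completion after every chunk rather than the raw delta, a consumer can render each yield as-is and keep only the last one as the final response. A minimal consumption sketch, assuming the exports above (the prompt content is invented):

const params: CompletionCreateParams = {
  model: "gpt-3.5-turbo",
  messages: [{ role: "user", content: "Say hello." }],
};

let finalCompletion: ChatCompletion | null = null;
for await (const partial of streamChatCompletion(params)) {
  // Each `partial` is the full ChatCompletion merged so far (courtesy of
  // mergeStreamedChunks), so simply overwriting is enough.
  finalCompletion = partial;
  console.log(partial.choices[0]?.message?.content ?? "");
}
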
@@ -1,4 +1,4 @@
-import modelProviders from "~/modelProviders";
+import modelProviders from "~/modelProviders/modelProviders";
import ivm from "isolated-vm";
import { isObject, isString } from "lodash-es";
import { type JsonObject } from "type-fest";