Add promptTokens and completionTokens to model output (#11)
* Default to streaming in config
* Add tokens to database
* Add NEXT_PUBLIC_SOCKET_URL to .env.example
* Disable streaming for functions
* Add newline to types
This commit is contained in:
@@ -1,74 +0,0 @@
|
||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||
import { isObject } from "lodash";
|
||||
import { type JSONSerializable } from "../types";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { streamChatCompletion } from "./openai";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
|
||||
// Result of a chat-completion call: the raw model output (Prisma.JsonNull
// when unavailable), the upstream HTTP status, an error message when the
// call failed, and the wall-clock duration of the request in milliseconds.
type CompletionResponse = {
  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
  statusCode: number;
  errorMessage: string | null;
  timeToComplete: number
};
|
||||
|
||||
export async function getChatCompletion(
|
||||
payload: JSONSerializable,
|
||||
apiKey: string,
|
||||
channel?: string,
|
||||
): Promise<CompletionResponse> {
|
||||
const start = Date.now();
|
||||
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
const resp: CompletionResponse = {
|
||||
output: Prisma.JsonNull,
|
||||
errorMessage: null,
|
||||
statusCode: response.status,
|
||||
timeToComplete: 0
|
||||
};
|
||||
|
||||
try {
|
||||
if (channel) {
|
||||
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
|
||||
let finalOutput: ChatCompletion | null = null;
|
||||
await (async () => {
|
||||
for await (const partialCompletion of completion) {
|
||||
finalOutput = partialCompletion
|
||||
wsConnection.emit("message", { channel, payload: partialCompletion });
|
||||
}
|
||||
})().catch((err) => console.error(err));
|
||||
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
||||
resp.timeToComplete = Date.now() - start;
|
||||
} else {
|
||||
resp.timeToComplete = Date.now() - start;
|
||||
resp.output = await response.json();
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
// If it's an object, try to get the error message
|
||||
if (
|
||||
isObject(resp.output) &&
|
||||
"error" in resp.output &&
|
||||
isObject(resp.output.error) &&
|
||||
"message" in resp.output.error
|
||||
) {
|
||||
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (response.ok) {
|
||||
resp.errorMessage = "Failed to parse response";
|
||||
}
|
||||
}
|
||||
|
||||
return resp;
|
||||
}
|
||||
132
src/server/utils/getCompletion.ts
Normal file
132
src/server/utils/getCompletion.ts
Normal file
@@ -0,0 +1,132 @@
|
||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||
import { isObject } from "lodash";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { streamChatCompletion } from "./openai";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type JSONSerializable, OpenAIChatModels } from "../types";
|
||||
import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
|
||||
// NOTE(review): removed the dead `env;` expression statement. It was a no-op
// reference (presumably to silence an unused-import lint), but `env` is
// genuinely used below via `env.OPENAI_API_KEY`, so the statement is
// unnecessary.
|
||||
|
||||
// Result of a completion call: the raw model output (Prisma.JsonNull when
// unavailable), the upstream HTTP status, an error message when the call
// failed, the request duration in milliseconds, and — when they could be
// read or estimated — the prompt/completion token counts.
type CompletionResponse = {
  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
  statusCode: number;
  errorMessage: string | null;
  timeToComplete: number;
  // Optional: populated from the response `usage` block, or estimated with
  // countOpenAIChatTokens for streamed responses that carry no usage data.
  promptTokens?: number;
  completionTokens?: number;
};
|
||||
|
||||
export async function getCompletion(
|
||||
payload: JSONSerializable,
|
||||
channel?: string
|
||||
): Promise<CompletionResponse> {
|
||||
if (!payload || !isObject(payload))
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid payload provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
if (
|
||||
"model" in payload &&
|
||||
typeof payload.model === "string" &&
|
||||
payload.model in OpenAIChatModels
|
||||
) {
|
||||
return getOpenAIChatCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
env.OPENAI_API_KEY,
|
||||
channel
|
||||
);
|
||||
}
|
||||
return {
|
||||
output: Prisma.JsonNull,
|
||||
statusCode: 400,
|
||||
errorMessage: "Invalid model provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
}
|
||||
|
||||
export async function getOpenAIChatCompletion(
|
||||
payload: CompletionCreateParams,
|
||||
apiKey: string,
|
||||
channel?: string
|
||||
): Promise<CompletionResponse> {
|
||||
// If functions are enabled, disable streaming so that we get the full response with token counts
|
||||
if (payload.functions?.length) payload.stream = false;
|
||||
const start = Date.now();
|
||||
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
const resp: CompletionResponse = {
|
||||
output: Prisma.JsonNull,
|
||||
errorMessage: null,
|
||||
statusCode: response.status,
|
||||
timeToComplete: 0,
|
||||
};
|
||||
|
||||
try {
|
||||
if (payload.stream) {
|
||||
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
|
||||
let finalOutput: ChatCompletion | null = null;
|
||||
await (async () => {
|
||||
for await (const partialCompletion of completion) {
|
||||
finalOutput = partialCompletion;
|
||||
wsConnection.emit("message", { channel, payload: partialCompletion });
|
||||
}
|
||||
})().catch((err) => console.error(err));
|
||||
if (finalOutput) {
|
||||
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
||||
resp.timeToComplete = Date.now() - start;
|
||||
}
|
||||
} else {
|
||||
resp.timeToComplete = Date.now() - start;
|
||||
resp.output = await response.json();
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
// If it's an object, try to get the error message
|
||||
if (
|
||||
isObject(resp.output) &&
|
||||
"error" in resp.output &&
|
||||
isObject(resp.output.error) &&
|
||||
"message" in resp.output.error
|
||||
) {
|
||||
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
|
||||
}
|
||||
}
|
||||
|
||||
if (isObject(resp.output) && "usage" in resp.output) {
|
||||
const usage = resp.output.usage as unknown as ChatCompletion.Usage;
|
||||
resp.promptTokens = usage.prompt_tokens;
|
||||
resp.completionTokens = usage.completion_tokens;
|
||||
} else if (isObject(resp.output) && 'choices' in resp.output) {
|
||||
const model = payload.model as unknown as OpenAIChatModels
|
||||
resp.promptTokens = countOpenAIChatTokens(
|
||||
model,
|
||||
payload.messages
|
||||
);
|
||||
const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
|
||||
const message = choices[0]?.message
|
||||
if (message) {
|
||||
const messages = [message]
|
||||
resp.completionTokens = countOpenAIChatTokens(model, messages);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (response.ok) {
|
||||
resp.errorMessage = "Failed to parse response";
|
||||
}
|
||||
}
|
||||
|
||||
return resp;
|
||||
}
|
||||
Reference in New Issue
Block a user