Add promptTokens and completionTokens to model output (#11)

* Default to streaming in config

* Add tokens to database (see the persistence sketch after the commit metadata)

* Add NEXT_PUBLIC_SOCKET_URL to .env.example

* Disable streaming for functions

* Add newline to types

arcticfly authored 2023-07-06 13:12:59 -07:00; committed by GitHub
parent 6ecb952a68, commit 1ae5612d55
11 changed files with 201 additions and 82 deletions
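
The "Add tokens to database" change implies two new nullable columns on the model-output table. Below is a minimal sketch of how the counts returned by the new code might be written back. The `modelOutput` model name, its `id` column, and the `saveTokenCounts` helper are assumptions for illustration; the schema itself is not part of this diff.

// Hypothetical sketch: persisting the new token counts with Prisma.
// The `modelOutput` model and its columns are assumptions, not shown in this diff.
import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

async function saveTokenCounts(
  outputId: string,
  promptTokens?: number,
  completionTokens?: number,
) {
  // Both fields are optional on CompletionResponse, so null out missing counts.
  await prisma.modelOutput.update({
    where: { id: outputId },
    data: {
      promptTokens: promptTokens ?? null,
      completionTokens: completionTokens ?? null,
    },
  });
}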


@@ -1,74 +0,0 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";

type CompletionResponse = {
  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
  statusCode: number;
  errorMessage: string | null;
  timeToComplete: number;
};

export async function getChatCompletion(
  payload: JSONSerializable,
  apiKey: string,
  channel?: string,
): Promise<CompletionResponse> {
  const start = Date.now();
  const response = await fetch("https://api.openai.com/v1/chat/completions", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify(payload),
  });

  const resp: CompletionResponse = {
    output: Prisma.JsonNull,
    errorMessage: null,
    statusCode: response.status,
    timeToComplete: 0,
  };

  try {
    if (channel) {
      const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
      let finalOutput: ChatCompletion | null = null;
      await (async () => {
        for await (const partialCompletion of completion) {
          finalOutput = partialCompletion;
          wsConnection.emit("message", { channel, payload: partialCompletion });
        }
      })().catch((err) => console.error(err));
      resp.output = finalOutput as unknown as Prisma.InputJsonValue;
      resp.timeToComplete = Date.now() - start;
    } else {
      resp.timeToComplete = Date.now() - start;
      resp.output = await response.json();
    }

    if (!response.ok) {
      // If it's an object, try to get the error message
      if (
        isObject(resp.output) &&
        "error" in resp.output &&
        isObject(resp.output.error) &&
        "message" in resp.output.error
      ) {
        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
      }
    }
  } catch (e) {
    console.error(e);
    if (response.ok) {
      resp.errorMessage = "Failed to parse response";
    }
  }

  return resp;
}

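The deleted file above is replaced by the version below, which folds API-key lookup and payload validation into a new getCompletion entry point and adds token accounting. The new file imports countOpenAIChatTokens from ~/utils/countTokens, which is not part of this diff; here is a plausible sketch of such a helper, assuming the `tiktoken` npm package and the per-message overhead heuristic from the OpenAI cookbook (both are assumptions, not confirmed by this commit).

// Hypothetical sketch of countOpenAIChatTokens; the real ~/utils/countTokens
// is not shown in this commit. Assumes the `tiktoken` npm package.
import { encoding_for_model, type TiktokenModel } from "tiktoken";
import { type CompletionCreateParams } from "openai/resources/chat";

export function countOpenAIChatTokens(
  model: TiktokenModel,
  messages: CompletionCreateParams.Message[],
): number {
  const encoder = encoding_for_model(model);
  try {
    // Roughly 3 tokens of fixed overhead per message, plus 3 to prime the
    // assistant's reply (per the OpenAI cookbook's counting recipe).
    return messages.reduce(
      (sum, message) =>
        sum + 3 + encoder.encode(`${message.role}\n${message.content ?? ""}`).length,
      3,
    );
  } finally {
    encoder.free(); // the WASM-backed encoder must be freed explicitly
  }
}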

@@ -0,0 +1,132 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type JSONSerializable, OpenAIChatModels } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";

type CompletionResponse = {
  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
  statusCode: number;
  errorMessage: string | null;
  timeToComplete: number;
  promptTokens?: number;
  completionTokens?: number;
};

export async function getCompletion(
  payload: JSONSerializable,
  channel?: string,
): Promise<CompletionResponse> {
  if (!payload || !isObject(payload))
    return {
      output: Prisma.JsonNull,
      statusCode: 400,
      errorMessage: "Invalid payload provided",
      timeToComplete: 0,
    };

  if (
    "model" in payload &&
    typeof payload.model === "string" &&
    payload.model in OpenAIChatModels
  ) {
    return getOpenAIChatCompletion(
      payload as unknown as CompletionCreateParams,
      env.OPENAI_API_KEY,
      channel,
    );
  }

  return {
    output: Prisma.JsonNull,
    statusCode: 400,
    errorMessage: "Invalid model provided",
    timeToComplete: 0,
  };
}

export async function getOpenAIChatCompletion(
  payload: CompletionCreateParams,
  apiKey: string,
  channel?: string,
): Promise<CompletionResponse> {
  // If functions are enabled, disable streaming so that we get the full,
  // non-streamed response, which includes token counts in its usage field.
  if (payload.functions?.length) payload.stream = false;

  const start = Date.now();
  const response = await fetch("https://api.openai.com/v1/chat/completions", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify(payload),
  });

  const resp: CompletionResponse = {
    output: Prisma.JsonNull,
    errorMessage: null,
    statusCode: response.status,
    timeToComplete: 0,
  };

  try {
    if (payload.stream) {
      // Relay each partial completion over the websocket as it arrives.
      const completion = streamChatCompletion(payload);
      let finalOutput: ChatCompletion | null = null;
      await (async () => {
        for await (const partialCompletion of completion) {
          finalOutput = partialCompletion;
          wsConnection.emit("message", { channel, payload: partialCompletion });
        }
      })().catch((err) => console.error(err));
      if (finalOutput) {
        resp.output = finalOutput as unknown as Prisma.InputJsonValue;
        resp.timeToComplete = Date.now() - start;
      }
    } else {
      resp.timeToComplete = Date.now() - start;
      resp.output = await response.json();
    }

    if (!response.ok) {
      // If it's an object, try to get the error message
      if (
        isObject(resp.output) &&
        "error" in resp.output &&
        isObject(resp.output.error) &&
        "message" in resp.output.error
      ) {
        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
      }
    }

    if (isObject(resp.output) && "usage" in resp.output) {
      // Non-streamed responses report token usage directly.
      const usage = resp.output.usage as unknown as ChatCompletion.Usage;
      resp.promptTokens = usage.prompt_tokens;
      resp.completionTokens = usage.completion_tokens;
    } else if (isObject(resp.output) && "choices" in resp.output) {
      // Streamed responses omit usage, so count tokens ourselves.
      const model = payload.model as unknown as OpenAIChatModels;
      resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
      const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
      const message = choices[0]?.message;
      if (message) {
        resp.completionTokens = countOpenAIChatTokens(model, [message]);
      }
    }
  } catch (e) {
    console.error(e);
    if (response.ok) {
      resp.errorMessage = "Failed to parse response";
    }
  }

  return resp;
}
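
For reference, a hypothetical call site for the new entry point; the model, messages, and channel name below are illustrative only and do not appear in this commit.

// Illustrative usage only; channel and payload values are made up.
const resp = await getCompletion(
  {
    model: "gpt-3.5-turbo",
    stream: true, // streaming is now the config default (see commit message)
    messages: [{ role: "user", content: "Hello!" }],
  },
  "client-socket-channel", // partial completions are relayed here via wsConnection
);

console.log(resp.statusCode, resp.promptTokens, resp.completionTokens);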