Add promptTokens and completionTokens to model output (#11)
* Default to streaming in config
* Add tokens to database
* Add NEXT_PUBLIC_SOCKET_URL to .env.example
* Disable streaming for functions
* Add newline to types
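Both new ModelOutput columns are nullable because token counts can only sometimes be read straight off the API response: OpenAI includes a usage block on non-streamed chat completions, while streamed responses are counted locally via gpt-tokens. A minimal sketch of the usage shape the new code reads (a standard OpenAI response field; the type name is illustrative, not part of this diff):

type OpenAIUsage = {
  prompt_tokens: number; // tokens consumed by the request messages
  completion_tokens: number; // tokens in the generated reply
  total_tokens: number; // sum of the two
};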
.env.example
@@ -16,3 +16,5 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/querykey?schema=publ
 # OpenAI API key. Instructions on generating a key can be found here:
 # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key
 OPENAI_API_KEY=""
+
+NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
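The NEXT_PUBLIC_ prefix makes the URL available to the browser bundle. A hypothetical client-side sketch (not part of this commit; it assumes a socket.io client, matching the wsConnection.emit("message", ...) calls the server makes below):

import { io } from "socket.io-client";

// Connect to the socket server that relays streamed partial completions.
const socket = io(process.env.NEXT_PUBLIC_SOCKET_URL ?? "http://localhost:3318");

// The server emits { channel, payload } per streamed chunk; a subscriber
// filters on the channel it requested.
socket.on("message", (msg: { channel: string; payload: unknown }) => {
  console.log(msg.channel, msg.payload);
});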
package.json
@@ -35,6 +35,7 @@
     "dotenv": "^16.3.1",
     "express": "^4.18.2",
     "framer-motion": "^10.12.17",
+    "gpt-tokens": "^1.0.10",
     "json-stringify-pretty-compact": "^4.0.0",
     "lodash": "^4.17.21",
     "next": "^13.4.2",
pnpm-lock.yaml (generated)
@@ -1,4 +1,4 @@
-lockfileVersion: '6.1'
+lockfileVersion: '6.0'
 
 settings:
   autoInstallPeers: true
@@ -68,6 +68,9 @@ dependencies:
   framer-motion:
     specifier: ^10.12.17
     version: 10.12.17(react-dom@18.2.0)(react@18.2.0)
+  gpt-tokens:
+    specifier: ^1.0.10
+    version: 1.0.10
   json-stringify-pretty-compact:
     specifier: ^4.0.0
     version: 4.0.0
@@ -2548,6 +2551,10 @@ packages:
     resolution: {integrity: sha512-Y5gU45svrR5tI2Vt/X9GPd3L0HNIKzGu202EjxrXMpuc2V2CiKgemAbUUsqYmZJvPtCXoUKjNZwBJzsNScUbXA==}
     dev: false
 
+  /base64-js@1.5.1:
+    resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==}
+    dev: false
+
   /base64id@2.0.0:
     resolution: {integrity: sha512-lGe34o6EHj9y3Kts9R4ZYs/Gr+6N7MCaMlIFA3F1R2O5/m7K06AxfSeO5530PEERE6/WyEg3lsuyw4GHlPZHog==}
     engines: {node: ^4.5.0 || >= 5.9}
@@ -2897,6 +2904,10 @@ packages:
     dependencies:
       ms: 2.1.2
 
+  /decimal.js@10.4.3:
+    resolution: {integrity: sha512-VBBaLc1MgL5XpzgIP7ny5Z6Nx3UrRkIViUkPUdtl9aya5amy3De1gsUUSB1g3+3sExYNjCAsAznmukyxCb1GRA==}
+    dev: false
+
   /deep-is@0.1.4:
     resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==}
     dev: true
@@ -3879,6 +3890,13 @@ packages:
       get-intrinsic: 1.2.1
     dev: true
 
+  /gpt-tokens@1.0.10:
+    resolution: {integrity: sha512-DNWfqhu+ZAbjTUT76Xc5UBE+e7L0WejsrbiJy+/zgvA2C4697OFN6TLfQY7zaWlay8bNUKqLzbStz0VI0thDtQ==}
+    dependencies:
+      decimal.js: 10.4.3
+      js-tiktoken: 1.0.7
+    dev: false
+
   /graceful-fs@4.2.11:
     resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==}
     dev: true
@@ -4259,6 +4277,12 @@ packages:
     resolution: {integrity: sha512-6Gsx8R0RucyePbWqPssR8DyfuXmLBooYN5cZFZKjHGnQuaf7pEzhtpceagJxVu4LqhYY5EYA7nko3FmeHZ1KbA==}
     dev: true
 
+  /js-tiktoken@1.0.7:
+    resolution: {integrity: sha512-biba8u/clw7iesNEWLOLwrNGoBP2lA+hTaBLs/D45pJdUPFXyxD6nhcDVtADChghv4GgyAiMKYMiRx7x6h7Biw==}
+    dependencies:
+      base64-js: 1.5.1
+    dev: false
+
   /js-tokens@4.0.0:
     resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
migration.sql (new file)
@@ -0,0 +1,3 @@
+-- AlterTable
+ALTER TABLE "ModelOutput" ADD COLUMN "completionTokens" INTEGER,
+ADD COLUMN "promptTokens" INTEGER;
prisma/schema.prisma
@@ -82,6 +82,9 @@ model ModelOutput {
   errorMessage   String?
   timeToComplete Int     @default(0)
 
+  promptTokens     Int? // Added promptTokens field
+  completionTokens Int? // Added completionTokens field
+
   promptVariantId String        @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
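Once the migration runs and the Prisma client is regenerated, the new nullable fields can be written alongside the existing ones. A minimal sketch, assuming data shaped like the modelOutputsRouter code below (identifiers such as inputHash and promptVariantId come from that router; passing undefined simply leaves a column NULL):

const modelOutput = await prisma.modelOutput.create({
  data: {
    inputHash, // cache key for the filled template
    output: modelResponse.output,
    timeToComplete: modelResponse.timeToComplete,
    promptTokens: modelResponse.promptTokens, // undefined -> NULL
    completionTokens: modelResponse.completionTokens,
    promptVariantId,
  },
});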
@@ -44,6 +44,7 @@ export const experimentsRouter = createTRPCRouter({
       sortIndex: 0,
       config: {
         model: "gpt-3.5-turbo",
+        stream: true,
         messages: [
           {
             role: "system",
@@ -3,12 +3,9 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate";
 import { type JSONSerializable } from "~/server/types";
-import { getChatCompletion } from "~/server/utils/getChatCompletion";
+import { getCompletion } from "~/server/utils/getCompletion";
 import crypto from "crypto";
 import type { Prisma } from "@prisma/client";
-import { env } from "~/env.mjs";
-
-env;
 
 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
@@ -54,7 +51,7 @@ export const modelOutputsRouter = createTRPCRouter({
       where: { inputHash, errorMessage: null },
     });
 
-    let modelResponse: Awaited<ReturnType<typeof getChatCompletion>>;
+    let modelResponse: Awaited<ReturnType<typeof getCompletion>>;
 
     if (existingResponse) {
       modelResponse = {
@@ -64,7 +61,7 @@ export const modelOutputsRouter = createTRPCRouter({
         timeToComplete: existingResponse.timeToComplete,
       };
     } else {
-      modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channel);
+      modelResponse = await getCompletion(filledTemplate, input.channel);
     }
 
     const modelOutput = await prisma.modelOutput.create({
src/server/types.ts
@@ -8,3 +8,16 @@ export type JSONSerializable =
 
 // Placeholder for now
 export type OpenAIChatConfig = NonNullable<JSONSerializable>;
+
+export enum OpenAIChatModels {
+  "gpt-4" = "gpt-4",
+  "gpt-4-0613" = "gpt-4-0613",
+  "gpt-4-32k" = "gpt-4-32k",
+  "gpt-4-32k-0613" = "gpt-4-32k-0613",
+  "gpt-3.5-turbo" = "gpt-3.5-turbo",
+  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
+  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
+  "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
+}
+
+type SupportedModel = keyof typeof OpenAIChatModels;
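Because OpenAIChatModels is a string enum whose keys equal its values, its runtime object doubles as a whitelist: an `in` check validates an arbitrary model string, which is exactly how getCompletion routes payloads below. A small sketch (the isSupportedChatModel helper is hypothetical):

import { OpenAIChatModels } from "~/server/types";

const isSupportedChatModel = (model: string): model is keyof typeof OpenAIChatModels =>
  model in OpenAIChatModels;

isSupportedChatModel("gpt-4"); // true
isSupportedChatModel("text-davinci-003"); // false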
src/server/utils/getChatCompletion.ts (deleted)
@@ -1,74 +0,0 @@
-/* eslint-disable @typescript-eslint/no-unsafe-call */
-import { isObject } from "lodash";
-import { type JSONSerializable } from "../types";
-import { Prisma } from "@prisma/client";
-import { streamChatCompletion } from "./openai";
-import { wsConnection } from "~/utils/wsConnection";
-import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-
-type CompletionResponse = {
-  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
-  statusCode: number;
-  errorMessage: string | null;
-  timeToComplete: number;
-};
-
-export async function getChatCompletion(
-  payload: JSONSerializable,
-  apiKey: string,
-  channel?: string,
-): Promise<CompletionResponse> {
-  const start = Date.now();
-  const response = await fetch("https://api.openai.com/v1/chat/completions", {
-    method: "POST",
-    headers: {
-      "Content-Type": "application/json",
-      Authorization: `Bearer ${apiKey}`,
-    },
-    body: JSON.stringify(payload),
-  });
-
-  const resp: CompletionResponse = {
-    output: Prisma.JsonNull,
-    errorMessage: null,
-    statusCode: response.status,
-    timeToComplete: 0,
-  };
-
-  try {
-    if (channel) {
-      const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
-      let finalOutput: ChatCompletion | null = null;
-      await (async () => {
-        for await (const partialCompletion of completion) {
-          finalOutput = partialCompletion;
-          wsConnection.emit("message", { channel, payload: partialCompletion });
-        }
-      })().catch((err) => console.error(err));
-      resp.output = finalOutput as unknown as Prisma.InputJsonValue;
-      resp.timeToComplete = Date.now() - start;
-    } else {
-      resp.timeToComplete = Date.now() - start;
-      resp.output = await response.json();
-    }
-
-    if (!response.ok) {
-      // If it's an object, try to get the error message
-      if (
-        isObject(resp.output) &&
-        "error" in resp.output &&
-        isObject(resp.output.error) &&
-        "message" in resp.output.error
-      ) {
-        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
-      }
-    }
-  } catch (e) {
-    console.error(e);
-    if (response.ok) {
-      resp.errorMessage = "Failed to parse response";
-    }
-  }
-
-  return resp;
-}
src/server/utils/getCompletion.ts (new file)
@@ -0,0 +1,132 @@
+/* eslint-disable @typescript-eslint/no-unsafe-call */
+import { isObject } from "lodash";
+import { Prisma } from "@prisma/client";
+import { streamChatCompletion } from "./openai";
+import { wsConnection } from "~/utils/wsConnection";
+import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
+import { type JSONSerializable, OpenAIChatModels } from "../types";
+import { env } from "~/env.mjs";
+import { countOpenAIChatTokens } from "~/utils/countTokens";
+
+env;
+
+type CompletionResponse = {
+  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
+  statusCode: number;
+  errorMessage: string | null;
+  timeToComplete: number;
+  promptTokens?: number;
+  completionTokens?: number;
+};
+
+export async function getCompletion(
+  payload: JSONSerializable,
+  channel?: string
+): Promise<CompletionResponse> {
+  if (!payload || !isObject(payload))
+    return {
+      output: Prisma.JsonNull,
+      statusCode: 400,
+      errorMessage: "Invalid payload provided",
+      timeToComplete: 0,
+    };
+  if (
+    "model" in payload &&
+    typeof payload.model === "string" &&
+    payload.model in OpenAIChatModels
+  ) {
+    return getOpenAIChatCompletion(
+      payload as unknown as CompletionCreateParams,
+      env.OPENAI_API_KEY,
+      channel
+    );
+  }
+  return {
+    output: Prisma.JsonNull,
+    statusCode: 400,
+    errorMessage: "Invalid model provided",
+    timeToComplete: 0,
+  };
+}
+
+export async function getOpenAIChatCompletion(
+  payload: CompletionCreateParams,
+  apiKey: string,
+  channel?: string
+): Promise<CompletionResponse> {
+  // If functions are enabled, disable streaming so that we get the full response with token counts
+  if (payload.functions?.length) payload.stream = false;
+  const start = Date.now();
+  const response = await fetch("https://api.openai.com/v1/chat/completions", {
+    method: "POST",
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: `Bearer ${apiKey}`,
+    },
+    body: JSON.stringify(payload),
+  });
+
+  const resp: CompletionResponse = {
+    output: Prisma.JsonNull,
+    errorMessage: null,
+    statusCode: response.status,
+    timeToComplete: 0,
+  };
+
+  try {
+    if (payload.stream) {
+      const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
+      let finalOutput: ChatCompletion | null = null;
+      await (async () => {
+        for await (const partialCompletion of completion) {
+          finalOutput = partialCompletion;
+          wsConnection.emit("message", { channel, payload: partialCompletion });
+        }
+      })().catch((err) => console.error(err));
+      if (finalOutput) {
+        resp.output = finalOutput as unknown as Prisma.InputJsonValue;
+        resp.timeToComplete = Date.now() - start;
+      }
+    } else {
+      resp.timeToComplete = Date.now() - start;
+      resp.output = await response.json();
+    }
+
+    if (!response.ok) {
+      // If it's an object, try to get the error message
+      if (
+        isObject(resp.output) &&
+        "error" in resp.output &&
+        isObject(resp.output.error) &&
+        "message" in resp.output.error
+      ) {
+        resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
+      }
+    }
+
+    if (isObject(resp.output) && "usage" in resp.output) {
+      const usage = resp.output.usage as unknown as ChatCompletion.Usage;
+      resp.promptTokens = usage.prompt_tokens;
+      resp.completionTokens = usage.completion_tokens;
+    } else if (isObject(resp.output) && "choices" in resp.output) {
+      const model = payload.model as unknown as OpenAIChatModels;
+      resp.promptTokens = countOpenAIChatTokens(
+        model,
+        payload.messages
+      );
+      const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
+      const message = choices[0]?.message;
+      if (message) {
+        const messages = [message];
+        resp.completionTokens = countOpenAIChatTokens(model, messages);
+      }
+    }
+  } catch (e) {
+    console.error(e);
+    if (response.ok) {
+      resp.errorMessage = "Failed to parse response";
+    }
+  }
+
+  return resp;
+}
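Callers only interact with the top-level helper; the modelOutputs router above does exactly this (filledTemplate is the rendered prompt config from that router):

const modelResponse = await getCompletion(filledTemplate, input.channel);
// Non-streamed responses carry counts from the API's usage block; streamed
// ones are counted locally, so the counts ride along either way.
console.log(modelResponse.promptTokens, modelResponse.completionTokens);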
src/utils/countTokens.ts (new file)
@@ -0,0 +1,17 @@
+import { type ChatCompletion } from "openai/resources/chat";
+import { GPTTokens } from "gpt-tokens";
+import { type OpenAIChatModels } from "~/server/types";
+
+interface GPTTokensMessageItem {
+  name?: string;
+  role: "system" | "user" | "assistant";
+  content: string;
+}
+
+export const countOpenAIChatTokens = (
+  model: OpenAIChatModels,
+  messages: ChatCompletion.Choice.Message[]
+) => {
+  return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
+    .usedTokens;
+};
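A usage sketch for the helper (the message content is illustrative, and the cast mirrors the one getOpenAIChatCompletion applies to payload.messages):

import { type ChatCompletion } from "openai/resources/chat";
import { OpenAIChatModels } from "~/server/types";
import { countOpenAIChatTokens } from "~/utils/countTokens";

const messages = [
  { role: "user", content: "How many tokens is this?" },
] as unknown as ChatCompletion.Choice.Message[];

// gpt-tokens counts locally, approximating the API's prompt_tokens figure.
const promptTokens = countOpenAIChatTokens(OpenAIChatModels["gpt-3.5-turbo"], messages);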