From 1ae5612d55bb735f8922a4e1c7ae1fdf67397bdd Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Thu, 6 Jul 2023 13:12:59 -0700 Subject: [PATCH] Add promptTokens and completionTokens to model output (#11) * Default to streaming in config * Add tokens to database * Add NEXT_PUBLIC_SOCKET_URL to .env.example * Disable streaming for functions * Add newline to types --- .env.example | 4 +- package.json | 1 + pnpm-lock.yaml | 26 +++- .../migration.sql | 3 + prisma/schema.prisma | 3 + src/server/api/routers/experiments.router.ts | 1 + src/server/api/routers/modelOutputs.router.ts | 9 +- src/server/types.ts | 13 ++ src/server/utils/getChatCompletion.ts | 74 ---------- src/server/utils/getCompletion.ts | 132 ++++++++++++++++++ src/utils/countTokens.ts | 17 +++ 11 files changed, 201 insertions(+), 82 deletions(-) create mode 100644 prisma/migrations/20230706193243_add_tokens_to_model_output/migration.sql delete mode 100644 src/server/utils/getChatCompletion.ts create mode 100644 src/server/utils/getCompletion.ts create mode 100644 src/utils/countTokens.ts diff --git a/.env.example b/.env.example index 1ddae38..57a1921 100644 --- a/.env.example +++ b/.env.example @@ -15,4 +15,6 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/querykey?schema=publ # OpenAI API key. Instructions on generating a key can be found here: # https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key -OPENAI_API_KEY="" \ No newline at end of file +OPENAI_API_KEY="" + +NEXT_PUBLIC_SOCKET_URL="http://localhost:3318" diff --git a/package.json b/package.json index 77401e6..ef1c562 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,7 @@ "dotenv": "^16.3.1", "express": "^4.18.2", "framer-motion": "^10.12.17", + "gpt-tokens": "^1.0.10", "json-stringify-pretty-compact": "^4.0.0", "lodash": "^4.17.21", "next": "^13.4.2", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8c9ac71..afabca1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.1' +lockfileVersion: '6.0' settings: autoInstallPeers: true @@ -68,6 +68,9 @@ dependencies: framer-motion: specifier: ^10.12.17 version: 10.12.17(react-dom@18.2.0)(react@18.2.0) + gpt-tokens: + specifier: ^1.0.10 + version: 1.0.10 json-stringify-pretty-compact: specifier: ^4.0.0 version: 4.0.0 @@ -2548,6 +2551,10 @@ packages: resolution: {integrity: sha512-Y5gU45svrR5tI2Vt/X9GPd3L0HNIKzGu202EjxrXMpuc2V2CiKgemAbUUsqYmZJvPtCXoUKjNZwBJzsNScUbXA==} dev: false + /base64-js@1.5.1: + resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==} + dev: false + /base64id@2.0.0: resolution: {integrity: sha512-lGe34o6EHj9y3Kts9R4ZYs/Gr+6N7MCaMlIFA3F1R2O5/m7K06AxfSeO5530PEERE6/WyEg3lsuyw4GHlPZHog==} engines: {node: ^4.5.0 || >= 5.9} @@ -2897,6 +2904,10 @@ packages: dependencies: ms: 2.1.2 + /decimal.js@10.4.3: + resolution: {integrity: sha512-VBBaLc1MgL5XpzgIP7ny5Z6Nx3UrRkIViUkPUdtl9aya5amy3De1gsUUSB1g3+3sExYNjCAsAznmukyxCb1GRA==} + dev: false + /deep-is@0.1.4: resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==} dev: true @@ -3879,6 +3890,13 @@ packages: get-intrinsic: 1.2.1 dev: true + /gpt-tokens@1.0.10: + resolution: {integrity: sha512-DNWfqhu+ZAbjTUT76Xc5UBE+e7L0WejsrbiJy+/zgvA2C4697OFN6TLfQY7zaWlay8bNUKqLzbStz0VI0thDtQ==} + dependencies: + decimal.js: 10.4.3 + js-tiktoken: 1.0.7 + dev: false + /graceful-fs@4.2.11: resolution: {integrity: 
sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} dev: true @@ -4259,6 +4277,12 @@ packages: resolution: {integrity: sha512-6Gsx8R0RucyePbWqPssR8DyfuXmLBooYN5cZFZKjHGnQuaf7pEzhtpceagJxVu4LqhYY5EYA7nko3FmeHZ1KbA==} dev: true + /js-tiktoken@1.0.7: + resolution: {integrity: sha512-biba8u/clw7iesNEWLOLwrNGoBP2lA+hTaBLs/D45pJdUPFXyxD6nhcDVtADChghv4GgyAiMKYMiRx7x6h7Biw==} + dependencies: + base64-js: 1.5.1 + dev: false + /js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} diff --git a/prisma/migrations/20230706193243_add_tokens_to_model_output/migration.sql b/prisma/migrations/20230706193243_add_tokens_to_model_output/migration.sql new file mode 100644 index 0000000..a486df3 --- /dev/null +++ b/prisma/migrations/20230706193243_add_tokens_to_model_output/migration.sql @@ -0,0 +1,3 @@ +-- AlterTable +ALTER TABLE "ModelOutput" ADD COLUMN "completionTokens" INTEGER, +ADD COLUMN "promptTokens" INTEGER; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 161a4cd..7e4e0d9 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -82,6 +82,9 @@ model ModelOutput { errorMessage String? timeToComplete Int @default(0) + promptTokens Int? // Added promptTokens field + completionTokens Int? // Added completionTokens field + promptVariantId String @db.Uuid promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts index c7f475c..b2d5654 100644 --- a/src/server/api/routers/experiments.router.ts +++ b/src/server/api/routers/experiments.router.ts @@ -44,6 +44,7 @@ export const experimentsRouter = createTRPCRouter({ sortIndex: 0, config: { model: "gpt-3.5-turbo", + stream: true, messages: [ { role: "system", diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index d144747..44b7b52 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -3,12 +3,9 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate"; import { type JSONSerializable } from "~/server/types"; -import { getChatCompletion } from "~/server/utils/getChatCompletion"; +import { getCompletion } from "~/server/utils/getCompletion"; import crypto from "crypto"; import type { Prisma } from "@prisma/client"; -import { env } from "~/env.mjs"; - -env; export const modelOutputsRouter = createTRPCRouter({ get: publicProcedure @@ -54,7 +51,7 @@ export const modelOutputsRouter = createTRPCRouter({ where: { inputHash, errorMessage: null }, }); - let modelResponse: Awaited>; + let modelResponse: Awaited>; if (existingResponse) { modelResponse = { @@ -64,7 +61,7 @@ export const modelOutputsRouter = createTRPCRouter({ timeToComplete: existingResponse.timeToComplete, }; } else { - modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channel); + modelResponse = await getCompletion(filledTemplate, input.channel); } const modelOutput = await prisma.modelOutput.create({ diff --git a/src/server/types.ts b/src/server/types.ts index d110298..7960aad 100644 --- a/src/server/types.ts +++ b/src/server/types.ts @@ -8,3 +8,16 @@ export type JSONSerializable = // Placeholder for now export type OpenAIChatConfig = 
NonNullable; + +export enum OpenAIChatModels { + "gpt-4" = "gpt-4", + "gpt-4-0613" = "gpt-4-0613", + "gpt-4-32k" = "gpt-4-32k", + "gpt-4-32k-0613" = "gpt-4-32k-0613", + "gpt-3.5-turbo" = "gpt-3.5-turbo", + "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613", +} + +type SupportedModel = keyof typeof OpenAIChatModels; diff --git a/src/server/utils/getChatCompletion.ts b/src/server/utils/getChatCompletion.ts deleted file mode 100644 index 3669a9d..0000000 --- a/src/server/utils/getChatCompletion.ts +++ /dev/null @@ -1,74 +0,0 @@ -/* eslint-disable @typescript-eslint/no-unsafe-call */ -import { isObject } from "lodash"; -import { type JSONSerializable } from "../types"; -import { Prisma } from "@prisma/client"; -import { streamChatCompletion } from "./openai"; -import { wsConnection } from "~/utils/wsConnection"; -import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; - -type CompletionResponse = { - output: Prisma.InputJsonValue | typeof Prisma.JsonNull; - statusCode: number; - errorMessage: string | null; - timeToComplete: number -}; - -export async function getChatCompletion( - payload: JSONSerializable, - apiKey: string, - channel?: string, -): Promise { - const start = Date.now(); - const response = await fetch("https://api.openai.com/v1/chat/completions", { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${apiKey}`, - }, - body: JSON.stringify(payload), - }); - - const resp: CompletionResponse = { - output: Prisma.JsonNull, - errorMessage: null, - statusCode: response.status, - timeToComplete: 0 - }; - - try { - if (channel) { - const completion = streamChatCompletion(payload as unknown as CompletionCreateParams); - let finalOutput: ChatCompletion | null = null; - await (async () => { - for await (const partialCompletion of completion) { - finalOutput = partialCompletion - wsConnection.emit("message", { channel, payload: partialCompletion }); - } - })().catch((err) => console.error(err)); - resp.output = finalOutput as unknown as Prisma.InputJsonValue; - resp.timeToComplete = Date.now() - start; - } else { - resp.timeToComplete = Date.now() - start; - resp.output = await response.json(); - } - - if (!response.ok) { - // If it's an object, try to get the error message - if ( - isObject(resp.output) && - "error" in resp.output && - isObject(resp.output.error) && - "message" in resp.output.error - ) { - resp.errorMessage = resp.output.error.message?.toString() ?? 
"Unknown error"; - } - } - } catch (e) { - console.error(e); - if (response.ok) { - resp.errorMessage = "Failed to parse response"; - } - } - - return resp; -} diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts new file mode 100644 index 0000000..945f2eb --- /dev/null +++ b/src/server/utils/getCompletion.ts @@ -0,0 +1,132 @@ +/* eslint-disable @typescript-eslint/no-unsafe-call */ +import { isObject } from "lodash"; +import { Prisma } from "@prisma/client"; +import { streamChatCompletion } from "./openai"; +import { wsConnection } from "~/utils/wsConnection"; +import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; +import { type JSONSerializable, OpenAIChatModels } from "../types"; +import { env } from "~/env.mjs"; +import { countOpenAIChatTokens } from "~/utils/countTokens"; + +env; + +type CompletionResponse = { + output: Prisma.InputJsonValue | typeof Prisma.JsonNull; + statusCode: number; + errorMessage: string | null; + timeToComplete: number; + promptTokens?: number; + completionTokens?: number; +}; + +export async function getCompletion( + payload: JSONSerializable, + channel?: string +): Promise { + if (!payload || !isObject(payload)) + return { + output: Prisma.JsonNull, + statusCode: 400, + errorMessage: "Invalid payload provided", + timeToComplete: 0, + }; + if ( + "model" in payload && + typeof payload.model === "string" && + payload.model in OpenAIChatModels + ) { + return getOpenAIChatCompletion( + payload as unknown as CompletionCreateParams, + env.OPENAI_API_KEY, + channel + ); + } + return { + output: Prisma.JsonNull, + statusCode: 400, + errorMessage: "Invalid model provided", + timeToComplete: 0, + }; +} + +export async function getOpenAIChatCompletion( + payload: CompletionCreateParams, + apiKey: string, + channel?: string +): Promise { + // If functions are enabled, disable streaming so that we get the full response with token counts + if (payload.functions?.length) payload.stream = false; + const start = Date.now(); + const response = await fetch("https://api.openai.com/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(payload), + }); + + const resp: CompletionResponse = { + output: Prisma.JsonNull, + errorMessage: null, + statusCode: response.status, + timeToComplete: 0, + }; + + try { + if (payload.stream) { + const completion = streamChatCompletion(payload as unknown as CompletionCreateParams); + let finalOutput: ChatCompletion | null = null; + await (async () => { + for await (const partialCompletion of completion) { + finalOutput = partialCompletion; + wsConnection.emit("message", { channel, payload: partialCompletion }); + } + })().catch((err) => console.error(err)); + if (finalOutput) { + resp.output = finalOutput as unknown as Prisma.InputJsonValue; + resp.timeToComplete = Date.now() - start; + } + } else { + resp.timeToComplete = Date.now() - start; + resp.output = await response.json(); + } + + if (!response.ok) { + // If it's an object, try to get the error message + if ( + isObject(resp.output) && + "error" in resp.output && + isObject(resp.output.error) && + "message" in resp.output.error + ) { + resp.errorMessage = resp.output.error.message?.toString() ?? 
"Unknown error"; + } + } + + if (isObject(resp.output) && "usage" in resp.output) { + const usage = resp.output.usage as unknown as ChatCompletion.Usage; + resp.promptTokens = usage.prompt_tokens; + resp.completionTokens = usage.completion_tokens; + } else if (isObject(resp.output) && 'choices' in resp.output) { + const model = payload.model as unknown as OpenAIChatModels + resp.promptTokens = countOpenAIChatTokens( + model, + payload.messages + ); + const choices = resp.output.choices as unknown as ChatCompletion.Choice[]; + const message = choices[0]?.message + if (message) { + const messages = [message] + resp.completionTokens = countOpenAIChatTokens(model, messages); + } + } + } catch (e) { + console.error(e); + if (response.ok) { + resp.errorMessage = "Failed to parse response"; + } + } + + return resp; +} diff --git a/src/utils/countTokens.ts b/src/utils/countTokens.ts new file mode 100644 index 0000000..26e03fc --- /dev/null +++ b/src/utils/countTokens.ts @@ -0,0 +1,17 @@ +import { type ChatCompletion } from "openai/resources/chat"; +import { GPTTokens } from "gpt-tokens"; +import { type OpenAIChatModels } from "~/server/types"; + +interface GPTTokensMessageItem { + name?: string; + role: "system" | "user" | "assistant"; + content: string; +} + +export const countOpenAIChatTokens = ( + model: OpenAIChatModels, + messages: ChatCompletion.Choice.Message[] +) => { + return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] }) + .usedTokens; +};