From 92c240e7b86f70de3907b5bae0079b5dea3a99b6 Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Thu, 6 Jul 2023 14:36:31 -0700 Subject: [PATCH] Add request cost to OutputStats (#12) --- src/components/OutputsTable/OutputCell.tsx | 30 ++++++++++++++--- src/server/types.ts | 6 ++-- src/server/utils/getCompletion.ts | 12 +++---- src/server/utils/getModelName.ts | 8 +++++ src/utils/calculateTokenCost.ts | 39 ++++++++++++++++++++++ src/utils/countTokens.ts | 4 +-- 6 files changed, 83 insertions(+), 16 deletions(-) create mode 100644 src/server/utils/getModelName.ts create mode 100644 src/utils/calculateTokenCost.ts diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index c38b663..15cc745 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -6,13 +6,16 @@ import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; import { useMemo, type ReactElement } from "react"; -import { BsCheck, BsClock, BsX } from "react-icons/bs"; +import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs"; import { type ModelOutput } from "@prisma/client"; import { type ChatCompletion } from "openai/resources/chat"; import { generateChannel } from "~/utils/generateChannel"; import { isObject } from "lodash"; import useSocket from "~/utils/useSocket"; import { evaluateOutput } from "~/server/utils/evaluateOutput"; +import { calculateTokenCost } from "~/utils/calculateTokenCost"; +import { type JSONSerializable, type SupportedModel } from "~/server/types"; +import { getModelName } from "~/server/utils/getModelName"; export default function OutputCell({ scenario, @@ -37,6 +40,8 @@ export default function OutputCell({ if (variant.config === null || Object.keys(variant.config).length === 0) disabledReason = "Save your prompt variant to see output"; + const model = getModelName(variant.config as JSONSerializable); + const shouldStream = isObject(variant) && "config" in variant && @@ -110,7 +115,7 @@ export default function OutputCell({ { maxLength: 40 } )} - + ); } @@ -121,15 +126,17 @@ export default function OutputCell({ return ( {contentToDisplay} - {output.data && } + {output.data && } ); } const OutputStats = ({ + model, modelOutput, scenario, }: { + model: SupportedModel | null; modelOutput: ModelOutput; scenario: Scenario; }) => { @@ -138,6 +145,15 @@ const OutputStats = ({ const evals = api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; + const promptTokens = modelOutput.promptTokens; + const completionTokens = modelOutput.completionTokens; + + const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0; + const completionCost = + completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0; + + const cost = promptCost + completionCost; + return ( @@ -155,8 +171,12 @@ const OutputStats = ({ ); })} - - + + + {cost.toFixed(3)} + + + {(timeToComplete / 1000).toFixed(2)}s diff --git a/src/server/types.ts b/src/server/types.ts index 7960aad..c82930b 100644 --- a/src/server/types.ts +++ b/src/server/types.ts @@ -9,15 +9,15 @@ export type JSONSerializable = // Placeholder for now export type OpenAIChatConfig = NonNullable; -export enum OpenAIChatModels { +export enum OpenAIChatModel { "gpt-4" = "gpt-4", "gpt-4-0613" = "gpt-4-0613", "gpt-4-32k" = "gpt-4-32k", "gpt-4-32k-0613" = "gpt-4-32k-0613", "gpt-3.5-turbo" = "gpt-3.5-turbo", - "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613", } -type SupportedModel = keyof typeof OpenAIChatModels; +export type SupportedModel = keyof typeof OpenAIChatModel; diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts index 945f2eb..80cfa8c 100644 --- a/src/server/utils/getCompletion.ts +++ b/src/server/utils/getCompletion.ts @@ -4,9 +4,10 @@ import { Prisma } from "@prisma/client"; import { streamChatCompletion } from "./openai"; import { wsConnection } from "~/utils/wsConnection"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; -import { type JSONSerializable, OpenAIChatModels } from "../types"; +import { type JSONSerializable, OpenAIChatModel } from "../types"; import { env } from "~/env.mjs"; import { countOpenAIChatTokens } from "~/utils/countTokens"; +import { getModelName } from "./getModelName"; env; @@ -23,7 +24,8 @@ export async function getCompletion( payload: JSONSerializable, channel?: string ): Promise { - if (!payload || !isObject(payload)) + const modelName = getModelName(payload); + if (!modelName) return { output: Prisma.JsonNull, statusCode: 400, @@ -31,9 +33,7 @@ export async function getCompletion( timeToComplete: 0, }; if ( - "model" in payload && - typeof payload.model === "string" && - payload.model in OpenAIChatModels + modelName in OpenAIChatModel ) { return getOpenAIChatCompletion( payload as unknown as CompletionCreateParams, @@ -109,7 +109,7 @@ export async function getOpenAIChatCompletion( resp.promptTokens = usage.prompt_tokens; resp.completionTokens = usage.completion_tokens; } else if (isObject(resp.output) && 'choices' in resp.output) { - const model = payload.model as unknown as OpenAIChatModels + const model = payload.model as unknown as OpenAIChatModel resp.promptTokens = countOpenAIChatTokens( model, payload.messages diff --git a/src/server/utils/getModelName.ts b/src/server/utils/getModelName.ts new file mode 100644 index 0000000..5118ea1 --- /dev/null +++ b/src/server/utils/getModelName.ts @@ -0,0 +1,8 @@ +import { isObject } from "lodash"; +import { type JSONSerializable, type SupportedModel } from "../types"; + +export function getModelName(config: JSONSerializable): SupportedModel | null { + if (!isObject(config)) return null; + if ("model" in config && typeof config.model === "string") return config.model as SupportedModel; + return null +} diff --git a/src/utils/calculateTokenCost.ts b/src/utils/calculateTokenCost.ts new file mode 100644 index 0000000..978f7c5 --- /dev/null +++ b/src/utils/calculateTokenCost.ts @@ -0,0 +1,39 @@ +import { type SupportedModel, OpenAIChatModel } from "~/server/types"; + +const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = { + "gpt-4": 0.00003, + "gpt-4-0613": 0.00003, + "gpt-4-32k": 0.00006, + "gpt-4-32k-0613": 0.00006, + "gpt-3.5-turbo": 0.0000015, + "gpt-3.5-turbo-0613": 0.0000015, + "gpt-3.5-turbo-16k": 0.000003, + "gpt-3.5-turbo-16k-0613": 0.000003, +}; + +const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = { + "gpt-4": 0.00006, + "gpt-4-0613": 0.00006, + "gpt-4-32k": 0.00012, + "gpt-4-32k-0613": 0.00012, + "gpt-3.5-turbo": 0.000002, + "gpt-3.5-turbo-0613": 0.000002, + "gpt-3.5-turbo-16k": 0.000004, + "gpt-3.5-turbo-16k-0613": 0.000004, +}; + +export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => { + if (model in OpenAIChatModel) { + return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion); + } + return 0; +} + +const calculateOpenAIChatTokenCost = ( + model: OpenAIChatModel, + numTokens: number, + isCompletion: boolean +) => { + const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model]; + return tokensToDollars * numTokens; +}; diff --git a/src/utils/countTokens.ts b/src/utils/countTokens.ts index 26e03fc..ca8dd25 100644 --- a/src/utils/countTokens.ts +++ b/src/utils/countTokens.ts @@ -1,6 +1,6 @@ import { type ChatCompletion } from "openai/resources/chat"; import { GPTTokens } from "gpt-tokens"; -import { type OpenAIChatModels } from "~/server/types"; +import { type OpenAIChatModel } from "~/server/types"; interface GPTTokensMessageItem { name?: string; @@ -9,7 +9,7 @@ interface GPTTokensMessageItem { } export const countOpenAIChatTokens = ( - model: OpenAIChatModels, + model: OpenAIChatModel, messages: ChatCompletion.Choice.Message[] ) => { return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })