From 92c240e7b86f70de3907b5bae0079b5dea3a99b6 Mon Sep 17 00:00:00 2001
From: arcticfly <41524992+arcticfly@users.noreply.github.com>
Date: Thu, 6 Jul 2023 14:36:31 -0700
Subject: [PATCH] Add request cost to OutputStats (#12)
---
src/components/OutputsTable/OutputCell.tsx | 30 ++++++++++++++---
src/server/types.ts | 6 ++--
src/server/utils/getCompletion.ts | 12 +++----
src/server/utils/getModelName.ts | 8 +++++
src/utils/calculateTokenCost.ts | 39 ++++++++++++++++++++++
src/utils/countTokens.ts | 4 +--
6 files changed, 83 insertions(+), 16 deletions(-)
create mode 100644 src/server/utils/getModelName.ts
create mode 100644 src/utils/calculateTokenCost.ts
diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx
index c38b663..15cc745 100644
--- a/src/components/OutputsTable/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell.tsx
@@ -6,13 +6,16 @@ import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { useMemo, type ReactElement } from "react";
-import { BsCheck, BsClock, BsX } from "react-icons/bs";
+import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
+import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { type JSONSerializable, type SupportedModel } from "~/server/types";
+import { getModelName } from "~/server/utils/getModelName";
export default function OutputCell({
scenario,
@@ -37,6 +40,8 @@ export default function OutputCell({
if (variant.config === null || Object.keys(variant.config).length === 0)
disabledReason = "Save your prompt variant to see output";
+ const model = getModelName(variant.config as JSONSerializable);
+
const shouldStream =
isObject(variant) &&
"config" in variant &&
@@ -110,7 +115,7 @@ export default function OutputCell({
           { maxLength: 40 }
         )}
       </SyntaxHighlighter>
-      <OutputStats modelOutput={output} scenario={scenario} />
+      <OutputStats model={model} modelOutput={output} scenario={scenario} />
     );
   }
@@ -121,15 +126,17 @@ export default function OutputCell({
return (
{contentToDisplay}
-      {output.data && <OutputStats modelOutput={output} scenario={scenario} />}
+      {output.data && <OutputStats model={model} modelOutput={output} scenario={scenario} />}
);
}
const OutputStats = ({
+ model,
modelOutput,
scenario,
}: {
+ model: SupportedModel | null;
modelOutput: ModelOutput;
scenario: Scenario;
}) => {
@@ -138,6 +145,15 @@ const OutputStats = ({
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
+ const promptTokens = modelOutput.promptTokens;
+ const completionTokens = modelOutput.completionTokens;
+
+ const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
+ const completionCost =
+ completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
+
+ const cost = promptCost + completionCost;
+
return (
@@ -155,8 +171,12 @@ const OutputStats = ({
);
})}
-        <HStack>
-          <BsClock />
+        <HStack>
+          <BsCurrencyDollar />
+          <Text>{cost.toFixed(3)}</Text>
+        </HStack>
+        <HStack>
+          <BsClock />
           <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
         </HStack>
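Note: the cost arithmetic above prices prompt and completion tokens separately, since OpenAI bills them at different rates. A worked example using the gpt-3.5-turbo rates defined in src/utils/calculateTokenCost.ts below (the token counts are illustrative, not from this patch):

    // 1,000 prompt tokens and 500 completion tokens on gpt-3.5-turbo:
    const promptCost = 1000 * 0.0000015;      // $0.0015
    const completionCost = 500 * 0.000002;    // $0.0010
    const cost = promptCost + completionCost; // 0.0025, rendered as "0.003" by cost.toFixed(3)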
diff --git a/src/server/types.ts b/src/server/types.ts
index 7960aad..c82930b 100644
--- a/src/server/types.ts
+++ b/src/server/types.ts
@@ -9,15 +9,15 @@ export type JSONSerializable =
// Placeholder for now
export type OpenAIChatConfig = NonNullable<JSONSerializable>;
-export enum OpenAIChatModels {
+export enum OpenAIChatModel {
"gpt-4" = "gpt-4",
"gpt-4-0613" = "gpt-4-0613",
"gpt-4-32k" = "gpt-4-32k",
"gpt-4-32k-0613" = "gpt-4-32k-0613",
"gpt-3.5-turbo" = "gpt-3.5-turbo",
- "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
}
-type SupportedModel = keyof typeof OpenAIChatModels;
+export type SupportedModel = keyof typeof OpenAIChatModel;
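Note: because each enum key equals its string value, the renamed OpenAIChatModel enum doubles as a runtime lookup table, while SupportedModel is the compile-time union of its keys. A minimal sketch of how the two interact (not part of the patch):

    import { OpenAIChatModel, type SupportedModel } from "~/server/types";

    const model: SupportedModel = "gpt-4";   // union: "gpt-4" | "gpt-4-0613" | ...
    console.log(model in OpenAIChatModel);   // true: the enum object backs runtime checks
    console.log("gpt-2" in OpenAIChatModel); // false: unrecognized models are filtered out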
diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts
index 945f2eb..80cfa8c 100644
--- a/src/server/utils/getCompletion.ts
+++ b/src/server/utils/getCompletion.ts
@@ -4,9 +4,10 @@ import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModels } from "../types";
+import { type JSONSerializable, OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { getModelName } from "./getModelName";
env;
@@ -23,7 +24,8 @@ export async function getCompletion(
payload: JSONSerializable,
channel?: string
): Promise {
- if (!payload || !isObject(payload))
+ const modelName = getModelName(payload);
+ if (!modelName)
return {
output: Prisma.JsonNull,
statusCode: 400,
@@ -31,9 +33,7 @@ export async function getCompletion(
timeToComplete: 0,
};
if (
- "model" in payload &&
- typeof payload.model === "string" &&
- payload.model in OpenAIChatModels
+ modelName in OpenAIChatModel
) {
return getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
@@ -109,7 +109,7 @@ export async function getOpenAIChatCompletion(
resp.promptTokens = usage.prompt_tokens;
resp.completionTokens = usage.completion_tokens;
} else if (isObject(resp.output) && 'choices' in resp.output) {
- const model = payload.model as unknown as OpenAIChatModels
+ const model = payload.model as unknown as OpenAIChatModel
resp.promptTokens = countOpenAIChatTokens(
model,
payload.messages
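Note: the refactor replaces the inline payload inspection with getModelName, so the 400 guard and the OpenAI routing check now share one extraction step. In outline, the new control flow reads as follows (simplified; the errorMessage text is illustrative):

    const modelName = getModelName(payload);
    if (!modelName) {
      // No recognizable model: fail fast with a 400 before any provider call.
      return { output: Prisma.JsonNull, statusCode: 400, errorMessage: "Invalid payload", timeToComplete: 0 };
    }
    if (modelName in OpenAIChatModel) {
      // Route to the OpenAI-specific handler; other providers could branch here later.
      return getOpenAIChatCompletion(payload as unknown as CompletionCreateParams, channel);
    }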
diff --git a/src/server/utils/getModelName.ts b/src/server/utils/getModelName.ts
new file mode 100644
index 0000000..5118ea1
--- /dev/null
+++ b/src/server/utils/getModelName.ts
@@ -0,0 +1,8 @@
+import { isObject } from "lodash";
+import { type JSONSerializable, type SupportedModel } from "../types";
+
+export function getModelName(config: JSONSerializable): SupportedModel | null {
+ if (!isObject(config)) return null;
+ if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
+  return null;
+}
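Note: for reference, how the new helper behaves on a few illustrative inputs (not part of the patch). It only checks that model is a string; the cast to SupportedModel is unvalidated, and enum membership is checked later at call sites via modelName in OpenAIChatModel:

    import { getModelName } from "~/server/utils/getModelName";

    getModelName({ model: "gpt-4", temperature: 0 }); // "gpt-4"
    getModelName({ temperature: 0 });                 // null (no model key)
    getModelName("gpt-4");                            // null (not an object)
    getModelName({ model: "not-a-model" });           // "not-a-model" -- cast through, not validated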
diff --git a/src/utils/calculateTokenCost.ts b/src/utils/calculateTokenCost.ts
new file mode 100644
index 0000000..978f7c5
--- /dev/null
+++ b/src/utils/calculateTokenCost.ts
@@ -0,0 +1,39 @@
+import { type SupportedModel, OpenAIChatModel } from "~/server/types";
+
+const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
+ "gpt-4": 0.00003,
+ "gpt-4-0613": 0.00003,
+ "gpt-4-32k": 0.00006,
+ "gpt-4-32k-0613": 0.00006,
+ "gpt-3.5-turbo": 0.0000015,
+ "gpt-3.5-turbo-0613": 0.0000015,
+ "gpt-3.5-turbo-16k": 0.000003,
+ "gpt-3.5-turbo-16k-0613": 0.000003,
+};
+
+const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
+ "gpt-4": 0.00006,
+ "gpt-4-0613": 0.00006,
+ "gpt-4-32k": 0.00012,
+ "gpt-4-32k-0613": 0.00012,
+ "gpt-3.5-turbo": 0.000002,
+ "gpt-3.5-turbo-0613": 0.000002,
+ "gpt-3.5-turbo-16k": 0.000004,
+ "gpt-3.5-turbo-16k-0613": 0.000004,
+};
+
+export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => {
+ if (model in OpenAIChatModel) {
+ return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
+ }
+ return 0;
+};
+
+const calculateOpenAIChatTokenCost = (
+ model: OpenAIChatModel,
+ numTokens: number,
+ isCompletion: boolean
+) => {
+ const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model];
+ return tokensToDollars * numTokens;
+};
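Note: the rate tables are expressed in dollars per token, i.e. OpenAI's published per-1K prices divided by 1,000. A usage sketch with worked numbers (not part of the patch):

    import { calculateTokenCost } from "~/utils/calculateTokenCost";

    calculateTokenCost("gpt-4", 1000);              // 1000 * 0.00003  = 0.03 (prompt side)
    calculateTokenCost("gpt-4", 1000, true);        // 1000 * 0.00006  = 0.06 (completion side)
    calculateTokenCost("gpt-3.5-turbo", 500, true); // 500  * 0.000002 = 0.001

The `return 0` fallback is unreachable today, since SupportedModel only contains OpenAI chat models, but it leaves room for providers whose pricing is unknown.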
diff --git a/src/utils/countTokens.ts b/src/utils/countTokens.ts
index 26e03fc..ca8dd25 100644
--- a/src/utils/countTokens.ts
+++ b/src/utils/countTokens.ts
@@ -1,6 +1,6 @@
import { type ChatCompletion } from "openai/resources/chat";
import { GPTTokens } from "gpt-tokens";
-import { type OpenAIChatModels } from "~/server/types";
+import { type OpenAIChatModel } from "~/server/types";
interface GPTTokensMessageItem {
name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
}
export const countOpenAIChatTokens = (
- model: OpenAIChatModels,
+ model: OpenAIChatModel,
messages: ChatCompletion.Choice.Message[]
) => {
return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
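Note: taken together, the two utilities let the server price a request even when the provider response omits usage data, as in the getOpenAIChatCompletion fallback above: count tokens with gpt-tokens, then convert to dollars. A sketch of that path, assuming countOpenAIChatTokens resolves to a number as its use for resp.promptTokens implies:

    import { countOpenAIChatTokens } from "~/utils/countTokens";
    import { calculateTokenCost } from "~/utils/calculateTokenCost";
    import { type ChatCompletion } from "openai/resources/chat";

    const messages = [{ role: "user", content: "Hello!" }] as unknown as ChatCompletion.Choice.Message[];
    const promptTokens = countOpenAIChatTokens("gpt-3.5-turbo", messages);
    const promptCost = calculateTokenCost("gpt-3.5-turbo", promptTokens);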