Add request cost to OutputStats (#12)
@@ -6,13 +6,16 @@ import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
 import { useMemo, type ReactElement } from "react";
-import { BsCheck, BsClock, BsX } from "react-icons/bs";
+import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
 import { type ModelOutput } from "@prisma/client";
 import { type ChatCompletion } from "openai/resources/chat";
 import { generateChannel } from "~/utils/generateChannel";
 import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
 import { evaluateOutput } from "~/server/utils/evaluateOutput";
+import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { type JSONSerializable, type SupportedModel } from "~/server/types";
+import { getModelName } from "~/server/utils/getModelName";

 export default function OutputCell({
   scenario,
@@ -37,6 +40,8 @@ export default function OutputCell({
   if (variant.config === null || Object.keys(variant.config).length === 0)
     disabledReason = "Save your prompt variant to see output";

+  const model = getModelName(variant.config as JSONSerializable);
+
   const shouldStream =
     isObject(variant) &&
     "config" in variant &&
@@ -110,7 +115,7 @@ export default function OutputCell({
           { maxLength: 40 }
         )}
       </SyntaxHighlighter>
-      <OutputStats modelOutput={output.data} scenario={scenario} />
+      <OutputStats model={model} modelOutput={output.data} scenario={scenario} />
     </Box>
   );
 }
@@ -121,15 +126,17 @@ export default function OutputCell({
   return (
     <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
       {contentToDisplay}
-      {output.data && <OutputStats modelOutput={output.data} scenario={scenario} />}
+      {output.data && <OutputStats model={model} modelOutput={output.data} scenario={scenario} />}
     </Flex>
   );
 }

 const OutputStats = ({
+  model,
   modelOutput,
   scenario,
 }: {
+  model: SupportedModel | null;
   modelOutput: ModelOutput;
   scenario: Scenario;
 }) => {
@@ -138,6 +145,15 @@ const OutputStats = ({
   const evals =
     api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];

+  const promptTokens = modelOutput.promptTokens;
+  const completionTokens = modelOutput.completionTokens;
+
+  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
+  const completionCost =
+    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
+
+  const cost = promptCost + completionCost;
+
   return (
     <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
       <HStack flex={1}>
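Both token counts and the resolved model can be null here (outputs recorded before token tracking existed, or configs getModelName can't parse), so each side of the sum falls back to 0 instead of erroring. Tracing illustrative numbers through the same expressions — the values below are made up, not from the commit:

```ts
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { type SupportedModel } from "~/server/types";

// Made-up stand-ins for modelOutput.promptTokens etc. and getModelName's result.
const model: SupportedModel | null = "gpt-3.5-turbo";
const promptTokens: number | null = 1200;
const completionTokens: number | null = 300;

const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0; // 0.0018
const completionCost =
  completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0; // 0.0006

const cost = promptCost + completionCost; // 0.0024
console.log(cost.toFixed(3)); // "0.002" — what renders next to the new dollar icon
```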
@@ -155,8 +171,12 @@ const OutputStats = ({
           );
         })}
       </HStack>
-      <HStack>
-        <Icon as={BsClock} mr={0.5} />
+      <HStack spacing={0}>
+        <Icon as={BsCurrencyDollar} />
+        <Text mr={1}>{cost.toFixed(3)}</Text>
+      </HStack>
+      <HStack spacing={0.5}>
+        <Icon as={BsClock} />
         <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
       </HStack>
     </HStack>
src/server/types.ts
@@ -9,15 +9,15 @@ export type JSONSerializable =
 // Placeholder for now
 export type OpenAIChatConfig = NonNullable<JSONSerializable>;

-export enum OpenAIChatModels {
+export enum OpenAIChatModel {
   "gpt-4" = "gpt-4",
   "gpt-4-0613" = "gpt-4-0613",
   "gpt-4-32k" = "gpt-4-32k",
   "gpt-4-32k-0613" = "gpt-4-32k-0613",
   "gpt-3.5-turbo" = "gpt-3.5-turbo",
-  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
   "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
+  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
   "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
 }

-type SupportedModel = keyof typeof OpenAIChatModels;
+export type SupportedModel = keyof typeof OpenAIChatModel;
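Because `OpenAIChatModel` is a string enum whose keys equal their values, `keyof typeof OpenAIChatModel` is exactly the union of model-name literals, and plain `in` checks work at runtime. That is what lets the now-exported `SupportedModel` stay in sync with the enum, and why the membership test in getCompletion below suffices. A small sketch (the variable names are just for illustration):

```ts
import { OpenAIChatModel, type SupportedModel } from "~/server/types";

// Compile time: SupportedModel = "gpt-4" | "gpt-4-0613" | ... | "gpt-3.5-turbo-16k-0613"
const ok: SupportedModel = "gpt-4"; // compiles
// const bad: SupportedModel = "gpt-5"; // type error

// Runtime: string-enum keys double as object properties, so `in` is a valid check.
console.log("gpt-4" in OpenAIChatModel); // true
console.log("gpt-5" in OpenAIChatModel); // false
```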
src/server/utils/getCompletion.ts
@@ -4,9 +4,10 @@ import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModels } from "../types";
+import { type JSONSerializable, OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { getModelName } from "./getModelName";

 env;

@@ -23,7 +24,8 @@ export async function getCompletion(
   payload: JSONSerializable,
   channel?: string
 ): Promise<CompletionResponse> {
-  if (!payload || !isObject(payload))
+  const modelName = getModelName(payload);
+  if (!modelName)
     return {
       output: Prisma.JsonNull,
       statusCode: 400,
@@ -31,9 +33,7 @@ export async function getCompletion(
       timeToComplete: 0,
     };
   if (
-    "model" in payload &&
-    typeof payload.model === "string" &&
-    payload.model in OpenAIChatModels
+    modelName in OpenAIChatModel
   ) {
     return getOpenAIChatCompletion(
       payload as unknown as CompletionCreateParams,
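With getModelName extracted, the provider check collapses from three structural tests on the payload to a null check plus one enum-membership test. A condensed sketch of the resulting control flow — the function name and string return values are simplifications for illustration, not the commit's actual signatures:

```ts
import { getModelName } from "~/server/utils/getModelName";
import { OpenAIChatModel, type JSONSerializable } from "~/server/types";

// Simplified: the real getCompletion returns a full CompletionResponse.
function pickProvider(payload: JSONSerializable): string {
  const modelName = getModelName(payload);
  if (!modelName) return "400: no recognizable model in config";
  if (modelName in OpenAIChatModel) return "openai-chat"; // -> getOpenAIChatCompletion
  return "unsupported"; // future non-OpenAI providers would branch here
}
```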
@@ -109,7 +109,7 @@ export async function getOpenAIChatCompletion(
     resp.promptTokens = usage.prompt_tokens;
     resp.completionTokens = usage.completion_tokens;
   } else if (isObject(resp.output) && 'choices' in resp.output) {
-    const model = payload.model as unknown as OpenAIChatModels
+    const model = payload.model as unknown as OpenAIChatModel
     resp.promptTokens = countOpenAIChatTokens(
       model,
       payload.messages
src/server/utils/getModelName.ts (new file, 8 lines)
@@ -0,0 +1,8 @@
+import { isObject } from "lodash";
+import { type JSONSerializable, type SupportedModel } from "../types";
+
+export function getModelName(config: JSONSerializable): SupportedModel | null {
+  if (!isObject(config)) return null;
+  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
+  return null
+}
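A usage sketch for the new helper; the configs here are invented examples. Note that the cast means the returned string is not validated against the enum — callers such as getCompletion and calculateTokenCost guard with `modelName in OpenAIChatModel` before trusting it:

```ts
import { getModelName } from "~/server/utils/getModelName";

getModelName({ model: "gpt-4", temperature: 0 }); // "gpt-4"
getModelName({ temperature: 0 }); // null — no "model" key
getModelName("just a string"); // null — not an object
getModelName(null); // null
```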
src/utils/calculateTokenCost.ts (new file, 39 lines)
@@ -0,0 +1,39 @@
+import { type SupportedModel, OpenAIChatModel } from "~/server/types";
+
+const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
+  "gpt-4": 0.00003,
+  "gpt-4-0613": 0.00003,
+  "gpt-4-32k": 0.00006,
+  "gpt-4-32k-0613": 0.00006,
+  "gpt-3.5-turbo": 0.0000015,
+  "gpt-3.5-turbo-0613": 0.0000015,
+  "gpt-3.5-turbo-16k": 0.000003,
+  "gpt-3.5-turbo-16k-0613": 0.000003,
+};
+
+const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
+  "gpt-4": 0.00006,
+  "gpt-4-0613": 0.00006,
+  "gpt-4-32k": 0.00012,
+  "gpt-4-32k-0613": 0.00012,
+  "gpt-3.5-turbo": 0.000002,
+  "gpt-3.5-turbo-0613": 0.000002,
+  "gpt-3.5-turbo-16k": 0.000004,
+  "gpt-3.5-turbo-16k-0613": 0.000004,
+};
+
+export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => {
+  if (model in OpenAIChatModel) {
+    return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
+  }
+  return 0;
+}
+
+const calculateOpenAIChatTokenCost = (
+  model: OpenAIChatModel,
+  numTokens: number,
+  isCompletion: boolean
+) => {
+  const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model];
+  return tokensToDollars * numTokens;
+};
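The lookup tables hold per-token dollar rates (equivalent to gpt-4 at $0.03/1K prompt tokens and $0.06/1K completion tokens), so results come out directly in USD. A quick worked example:

```ts
import { calculateTokenCost } from "~/utils/calculateTokenCost";

// gpt-4: $0.00003 per prompt token, $0.00006 per completion token.
const promptCost = calculateTokenCost("gpt-4", 1000); // 1000 * 0.00003 = 0.03
const completionCost = calculateTokenCost("gpt-4", 500, true); // 500 * 0.00006 = 0.03
console.log((promptCost + completionCost).toFixed(3)); // "0.060"
```

The `return 0` fallthrough means a model outside `OpenAIChatModel` reports $0.000 in the stats row rather than throwing — a reasonable placeholder while `SupportedModel` only covers OpenAI chat models.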
src/utils/countTokens.ts
@@ -1,6 +1,6 @@
 import { type ChatCompletion } from "openai/resources/chat";
 import { GPTTokens } from "gpt-tokens";
-import { type OpenAIChatModels } from "~/server/types";
+import { type OpenAIChatModel } from "~/server/types";

 interface GPTTokensMessageItem {
   name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
 }

 export const countOpenAIChatTokens = (
-  model: OpenAIChatModels,
+  model: OpenAIChatModel,
   messages: ChatCompletion.Choice.Message[]
 ) => {
   return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
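countOpenAIChatTokens wraps the gpt-tokens package to estimate counts when the API response carries no usage block (the streamed-completion path above). The diff cuts off before the property read, so the `.usedTokens` accessor below is an assumption about the gpt-tokens API, not something this commit shows:

```ts
import { GPTTokens } from "gpt-tokens";

const counter = new GPTTokens({
  model: "gpt-3.5-turbo",
  messages: [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "Hello!" },
  ],
});

// Assumption: gpt-tokens exposes the estimate as `usedTokens`.
console.log(counter.usedTokens);
```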