Add request cost to OutputStats (#12)

Author: arcticfly
Date: 2023-07-06 14:36:31 -07:00
Committed by: GitHub
Parent: f728027ef6
Commit: 92c240e7b8
6 changed files with 83 additions and 16 deletions

View File

@@ -6,13 +6,16 @@ import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
 import { useMemo, type ReactElement } from "react";
-import { BsCheck, BsClock, BsX } from "react-icons/bs";
+import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
 import { type ModelOutput } from "@prisma/client";
 import { type ChatCompletion } from "openai/resources/chat";
 import { generateChannel } from "~/utils/generateChannel";
 import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
 import { evaluateOutput } from "~/server/utils/evaluateOutput";
+import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { type JSONSerializable, type SupportedModel } from "~/server/types";
+import { getModelName } from "~/server/utils/getModelName";

 export default function OutputCell({
   scenario,
@@ -37,6 +40,8 @@ export default function OutputCell({
   if (variant.config === null || Object.keys(variant.config).length === 0)
     disabledReason = "Save your prompt variant to see output";

+  const model = getModelName(variant.config as JSONSerializable);
+
   const shouldStream =
     isObject(variant) &&
     "config" in variant &&
@@ -110,7 +115,7 @@ export default function OutputCell({
           { maxLength: 40 }
         )}
       </SyntaxHighlighter>
-      <OutputStats modelOutput={output.data} scenario={scenario} />
+      <OutputStats model={model} modelOutput={output.data} scenario={scenario} />
     </Box>
   );
 }
@@ -121,15 +126,17 @@ export default function OutputCell({
   return (
     <Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
       {contentToDisplay}
-      {output.data && <OutputStats modelOutput={output.data} scenario={scenario} />}
+      {output.data && <OutputStats model={model} modelOutput={output.data} scenario={scenario} />}
     </Flex>
   );
 }

 const OutputStats = ({
+  model,
   modelOutput,
   scenario,
 }: {
+  model: SupportedModel | null;
   modelOutput: ModelOutput;
   scenario: Scenario;
 }) => {
@@ -138,6 +145,15 @@ const OutputStats = ({
   const evals =
     api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];

+  const promptTokens = modelOutput.promptTokens;
+  const completionTokens = modelOutput.completionTokens;
+
+  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
+  const completionCost =
+    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
+
+  const cost = promptCost + completionCost;
+
   return (
     <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
       <HStack flex={1}>
@@ -155,8 +171,12 @@ const OutputStats = ({
           );
         })}
       </HStack>
-      <HStack>
-        <Icon as={BsClock} mr={0.5} />
+      <HStack spacing={0}>
+        <Icon as={BsCurrencyDollar} />
+        <Text mr={1}>{cost.toFixed(3)}</Text>
+      </HStack>
+      <HStack spacing={0.5}>
+        <Icon as={BsClock} />
         <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
       </HStack>
     </HStack>
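
The upshot of the component changes: OutputStats multiplies the stored token counts by per-token rates and renders the sum next to a dollar icon. A minimal sketch of that arithmetic, assuming the gpt-4 rates from the cost tables added later in this commit ($0.00003 per prompt token, $0.00006 per completion token):

```ts
// Sketch of the arithmetic OutputStats now performs (gpt-4 rates assumed).
const promptTokens = 1000;
const completionTokens = 500;

const promptCost = promptTokens * 0.00003;         // ≈ $0.03
const completionCost = completionTokens * 0.00006; // ≈ $0.03
const cost = promptCost + completionCost;          // ≈ $0.06

console.log(cost.toFixed(3)); // "0.060", shown beside the BsCurrencyDollar icon
```

Because both cost terms default to 0 when the model is unknown or a token count is missing, the stat renders even for outputs that predate token tracking.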

View File

@@ -9,15 +9,15 @@ export type JSONSerializable =
 // Placeholder for now
 export type OpenAIChatConfig = NonNullable<JSONSerializable>;

-export enum OpenAIChatModels {
+export enum OpenAIChatModel {
   "gpt-4" = "gpt-4",
   "gpt-4-0613" = "gpt-4-0613",
   "gpt-4-32k" = "gpt-4-32k",
   "gpt-4-32k-0613" = "gpt-4-32k-0613",
   "gpt-3.5-turbo" = "gpt-3.5-turbo",
-  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
   "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
+  "gpt-3.5-turbo-16k" = "gpt-3.5-turbo-16k",
   "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
 }

-type SupportedModel = keyof typeof OpenAIChatModels;
+export type SupportedModel = keyof typeof OpenAIChatModel;
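
For context on the `SupportedModel` export: `keyof typeof` over a string enum yields the union of its literal keys, so the type covers exactly the model names above. A small sketch (import path assumed from the other files in this commit):

```ts
import { OpenAIChatModel, type SupportedModel } from "~/server/types";

// SupportedModel = "gpt-4" | "gpt-4-0613" | ... | "gpt-3.5-turbo-16k-0613"
const ok: SupportedModel = "gpt-4";
// const bad: SupportedModel = "gpt-5"; // compile error: not a key of the enum

// Keys equal values, so the enum doubles as a runtime membership check:
console.log("gpt-4" in OpenAIChatModel); // true
```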

View File

@@ -4,9 +4,10 @@ import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type JSONSerializable, OpenAIChatModels } from "../types";
+import { type JSONSerializable, OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { getModelName } from "./getModelName";

 env;
@@ -23,7 +24,8 @@ export async function getCompletion(
   payload: JSONSerializable,
   channel?: string
 ): Promise<CompletionResponse> {
-  if (!payload || !isObject(payload))
+  const modelName = getModelName(payload);
+  if (!modelName)
     return {
       output: Prisma.JsonNull,
       statusCode: 400,
@@ -31,9 +33,7 @@ export async function getCompletion(
       timeToComplete: 0,
     };
   if (
-    "model" in payload &&
-    typeof payload.model === "string" &&
-    payload.model in OpenAIChatModels
+    modelName in OpenAIChatModel
   ) {
     return getOpenAIChatCompletion(
       payload as unknown as CompletionCreateParams,
@@ -109,7 +109,7 @@ export async function getOpenAIChatCompletion(
     resp.promptTokens = usage.prompt_tokens;
     resp.completionTokens = usage.completion_tokens;
   } else if (isObject(resp.output) && 'choices' in resp.output) {
-    const model = payload.model as unknown as OpenAIChatModels
+    const model = payload.model as unknown as OpenAIChatModel
     resp.promptTokens = countOpenAIChatTokens(
       model,
       payload.messages
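
Net effect in `getCompletion`: three inline payload checks collapse into one helper call, and the OpenAI branch keys off the extracted name. A simplified sketch of the new control flow (`badRequest` is a hypothetical stand-in for the 400 `CompletionResponse` literal):

```ts
// Simplified control flow of getCompletion after this change (not verbatim).
const modelName = getModelName(payload); // null for non-object payloads or a missing model field
if (!modelName) return badRequest();     // hypothetical stand-in for the 400 response object

if (modelName in OpenAIChatModel) {
  // Recognized OpenAI chat model: dispatch to the chat-completion path.
  return getOpenAIChatCompletion(payload as unknown as CompletionCreateParams, channel);
}
```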

View File

@@ -0,0 +1,8 @@
+import { isObject } from "lodash";
+import { type JSONSerializable, type SupportedModel } from "../types";
+
+export function getModelName(config: JSONSerializable): SupportedModel | null {
+  if (!isObject(config)) return null;
+  if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
+  return null;
+}
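
A quick usage sketch of the new helper (module path assumed from the imports elsewhere in this commit):

```ts
import { getModelName } from "~/server/utils/getModelName";

getModelName({ model: "gpt-4", temperature: 0 }); // "gpt-4"
getModelName({ temperature: 0 });                 // null: no model field
getModelName("gpt-4");                            // null: config must be an object
```

Note that the cast trusts any string `model` value; actual membership in `OpenAIChatModel` is still verified at the call site.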

View File

@@ -0,0 +1,39 @@
+import { type SupportedModel, OpenAIChatModel } from "~/server/types";
+
+const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
+  "gpt-4": 0.00003,
+  "gpt-4-0613": 0.00003,
+  "gpt-4-32k": 0.00006,
+  "gpt-4-32k-0613": 0.00006,
+  "gpt-3.5-turbo": 0.0000015,
+  "gpt-3.5-turbo-0613": 0.0000015,
+  "gpt-3.5-turbo-16k": 0.000003,
+  "gpt-3.5-turbo-16k-0613": 0.000003,
+};
+
+const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
+  "gpt-4": 0.00006,
+  "gpt-4-0613": 0.00006,
+  "gpt-4-32k": 0.00012,
+  "gpt-4-32k-0613": 0.00012,
+  "gpt-3.5-turbo": 0.000002,
+  "gpt-3.5-turbo-0613": 0.000002,
+  "gpt-3.5-turbo-16k": 0.000004,
+  "gpt-3.5-turbo-16k-0613": 0.000004,
+};
+
+export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => {
+  if (model in OpenAIChatModel) {
+    return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
+  }
+  return 0;
+};
+
+const calculateOpenAIChatTokenCost = (
+  model: OpenAIChatModel,
+  numTokens: number,
+  isCompletion: boolean
+) => {
+  const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model];
+  return tokensToDollars * numTokens;
+};
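
The tables hold per-token rates, i.e. OpenAI's mid-2023 per-1K prices divided by 1,000 (gpt-4 at $0.03 per 1K prompt tokens becomes 0.00003). A worked example (import path assumed):

```ts
import { calculateTokenCost } from "~/utils/calculateTokenCost";

calculateTokenCost("gpt-4", 1000);        // ≈ 0.03 (prompt tokens)
calculateTokenCost("gpt-4", 1000, true);  // ≈ 0.06 (completion tokens)
calculateTokenCost("gpt-3.5-turbo", 500); // ≈ 0.00075
```

Unrecognized models fall through to a cost of 0 rather than throwing, which keeps OutputStats rendering for non-OpenAI variants.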

View File

@@ -1,6 +1,6 @@
 import { type ChatCompletion } from "openai/resources/chat";
 import { GPTTokens } from "gpt-tokens";
-import { type OpenAIChatModels } from "~/server/types";
+import { type OpenAIChatModel } from "~/server/types";

 interface GPTTokensMessageItem {
   name?: string;
@@ -9,7 +9,7 @@ interface GPTTokensMessageItem {
 }

 export const countOpenAIChatTokens = (
-  model: OpenAIChatModels,
+  model: OpenAIChatModel,
   messages: ChatCompletion.Choice.Message[]
 ) => {
   return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] })
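
For reference, this counter backs the fallback path in getCompletion when the API response has no `usage` block. A usage sketch of the wrapped gpt-tokens package (assuming its `usedTokens` getter; check the package docs before relying on it):

```ts
import { GPTTokens } from "gpt-tokens";

// Count chat-prompt tokens locally, as countOpenAIChatTokens does above.
const usage = new GPTTokens({
  model: "gpt-3.5-turbo",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(usage.usedTokens); // total tokens, assuming gpt-tokens exposes usedTokens
```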