From fe501a80cb7516873675fec169b562f7c7d4b242 Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Thu, 6 Jul 2023 15:33:49 -0700 Subject: [PATCH] Add total token cost to variant stats (#13) * Add total token cost to variant stats * Copy over token counts for new variants * Update invalidate call --- .../OutputsTable/EditEvaluations.tsx | 4 +- src/components/OutputsTable/OutputCell.tsx | 2 +- src/components/OutputsTable/VariantStats.tsx | 44 +++++++++++-------- src/server/api/routers/evaluations.router.ts | 9 ---- src/server/api/routers/modelOutputs.router.ts | 2 + .../api/routers/promptVariants.router.ts | 42 ++++++++++++++++++ src/server/utils/getModelName.ts | 5 ++- src/utils/calculateTokenCost.ts | 23 ++++++---- 8 files changed, 91 insertions(+), 40 deletions(-) diff --git a/src/components/OutputsTable/EditEvaluations.tsx b/src/components/OutputsTable/EditEvaluations.tsx index 6f95f40..3aaeb3a 100644 --- a/src/components/OutputsTable/EditEvaluations.tsx +++ b/src/components/OutputsTable/EditEvaluations.tsx @@ -105,7 +105,7 @@ export default function EditEvaluations() { const [onDelete] = useHandledAsyncCallback(async (id: string) => { await deleteMutation.mutateAsync({ id }); await utils.evaluations.list.invalidate(); - await utils.evaluations.results.invalidate(); + await utils.promptVariants.stats.invalidate(); }, []); const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => { @@ -124,7 +124,7 @@ export default function EditEvaluations() { }); } await utils.evaluations.list.invalidate(); - await utils.evaluations.results.invalidate(); + await utils.promptVariants.stats.invalidate(); }, []); const onCancel = useCallback(() => { diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index ffca43b..5594466 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -65,7 +65,7 @@ export default function OutputCell({ channel, }); setOutput(output); - await utils.evaluations.results.invalidate(); + await utils.promptVariants.stats.invalidate(); }, [outputMutation, scenario.id, variant.id, channel]); useEffect(fetchOutput, []); diff --git a/src/components/OutputsTable/VariantStats.tsx b/src/components/OutputsTable/VariantStats.tsx index 4143f54..c2b65b0 100644 --- a/src/components/OutputsTable/VariantStats.tsx +++ b/src/components/OutputsTable/VariantStats.tsx @@ -1,14 +1,14 @@ -import { HStack, Text, useToken } from "@chakra-ui/react"; +import { HStack, Icon, Text, useToken } from "@chakra-ui/react"; import { type PromptVariant } from "./types"; import { cellPadding } from "../constants"; import { api } from "~/utils/api"; import chroma from "chroma-js"; +import { BsCurrencyDollar } from "react-icons/bs"; export default function VariantStats(props: { variant: PromptVariant }) { - const evalResults = - api.evaluations.results.useQuery({ - variantId: props.variant.id, - }).data ?? []; + const { evalResults, overallCost } = api.promptVariants.stats.useQuery({ + variantId: props.variant.id, + }).data ?? { evalResults: [] }; const [passColor, neutralColor, failColor] = useToken("colors", [ "green.500", @@ -18,21 +18,29 @@ export default function VariantStats(props: { variant: PromptVariant }) { const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]); - if (!(evalResults.length > 0)) return null; + if (!(evalResults.length > 0) && !overallCost) return null; return ( - - {evalResults.map((result) => { - const passedFrac = result.passCount / (result.passCount + result.failCount); - return ( - - {result.evaluation.name} - - {(passedFrac * 100).toFixed(1)}% - - - ); - })} + + + {evalResults.map((result) => { + const passedFrac = result.passCount / (result.passCount + result.failCount); + return ( + + {result.evaluation.name} + + {(passedFrac * 100).toFixed(1)}% + + + ); + })} + + {overallCost && ( + + + {overallCost.toFixed(3)} + + )} ); } diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts index 3f60e1d..99a98b0 100644 --- a/src/server/api/routers/evaluations.router.ts +++ b/src/server/api/routers/evaluations.router.ts @@ -14,15 +14,6 @@ export const evaluationsRouter = createTRPCRouter({ }); }), - results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => { - return await prisma.evaluationResult.findMany({ - where: { - promptVariantId: input.variantId, - }, - include: { evaluation: true }, - }); - }), - create: publicProcedure .input( z.object({ diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index a726712..12d0192 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -62,6 +62,8 @@ export const modelOutputsRouter = createTRPCRouter({ statusCode: existingResponse.statusCode, errorMessage: existingResponse.errorMessage, timeToComplete: existingResponse.timeToComplete, + promptTokens: existingResponse.promptTokens ?? undefined, + completionTokens: existingResponse.completionTokens ?? undefined, }; } else { modelResponse = await getCompletion(filledTemplate, input.channel); diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index c519814..ff7b597 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -2,6 +2,8 @@ import { z } from "zod"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { type OpenAIChatConfig } from "~/server/types"; +import { getModelName } from "~/server/utils/getModelName"; +import { calculateTokenCost } from "~/utils/calculateTokenCost"; export const promptVariantsRouter = createTRPCRouter({ list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { @@ -14,6 +16,46 @@ export const promptVariantsRouter = createTRPCRouter({ }); }), + stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => { + const variant = await prisma.promptVariant.findUnique({ + where: { + id: input.variantId, + }, + }); + + if (!variant) { + throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); + } + + const evalResults = await prisma.evaluationResult.findMany({ + where: { + promptVariantId: input.variantId, + }, + include: { evaluation: true }, + }); + + const overallTokens = await prisma.modelOutput.aggregate({ + where: { + promptVariantId: input.variantId, + }, + _sum: { + promptTokens: true, + completionTokens: true, + }, + }); + + const model = getModelName(variant.config); + + const promptTokens = overallTokens._sum?.promptTokens ?? 0; + const overallPromptCost = calculateTokenCost(model, promptTokens); + const completionTokens = overallTokens._sum?.completionTokens ?? 0; + const overallCompletionCost = calculateTokenCost(model, completionTokens, true); + + const overallCost = overallPromptCost + overallCompletionCost; + + return { evalResults, overallCost }; + }), + create: publicProcedure .input( z.object({ diff --git a/src/server/utils/getModelName.ts b/src/server/utils/getModelName.ts index 5118ea1..4fcaef1 100644 --- a/src/server/utils/getModelName.ts +++ b/src/server/utils/getModelName.ts @@ -1,8 +1,9 @@ import { isObject } from "lodash"; import { type JSONSerializable, type SupportedModel } from "../types"; +import { type Prisma } from "@prisma/client"; -export function getModelName(config: JSONSerializable): SupportedModel | null { +export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null { if (!isObject(config)) return null; if ("model" in config && typeof config.model === "string") return config.model as SupportedModel; - return null + return null; } diff --git a/src/utils/calculateTokenCost.ts b/src/utils/calculateTokenCost.ts index 978f7c5..274fd89 100644 --- a/src/utils/calculateTokenCost.ts +++ b/src/utils/calculateTokenCost.ts @@ -22,18 +22,25 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = { "gpt-3.5-turbo-16k-0613": 0.000004, }; -export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => { - if (model in OpenAIChatModel) { - return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion); - } - return 0; -} +export const calculateTokenCost = ( + model: SupportedModel | null, + numTokens: number, + isCompletion = false +) => { + if (!model) return 0; + if (model in OpenAIChatModel) { + return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion); + } + return 0; +}; const calculateOpenAIChatTokenCost = ( model: OpenAIChatModel, numTokens: number, isCompletion: boolean ) => { - const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model]; - return tokensToDollars * numTokens; + const tokensToDollars = isCompletion + ? openAICompletionTokensToDollars[model] + : openAIPromptTokensToDollars[model]; + return tokensToDollars * numTokens; };