Add total token cost to variant stats (#13)

* Add total token cost to variant stats

* Copy over token counts for new variants

* Update invalidate call
This commit is contained in:
arcticfly
2023-07-06 15:33:49 -07:00
committed by GitHub
parent 1fa0d7bc62
commit fe501a80cb
8 changed files with 91 additions and 40 deletions

View File

@@ -105,7 +105,7 @@ export default function EditEvaluations() {
const [onDelete] = useHandledAsyncCallback(async (id: string) => {
await deleteMutation.mutateAsync({ id });
await utils.evaluations.list.invalidate();
await utils.evaluations.results.invalidate();
await utils.promptVariants.stats.invalidate();
}, []);
const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
@@ -124,7 +124,7 @@ export default function EditEvaluations() {
});
}
await utils.evaluations.list.invalidate();
await utils.evaluations.results.invalidate();
await utils.promptVariants.stats.invalidate();
}, []);
const onCancel = useCallback(() => {

View File

@@ -65,7 +65,7 @@ export default function OutputCell({
channel,
});
setOutput(output);
await utils.evaluations.results.invalidate();
await utils.promptVariants.stats.invalidate();
}, [outputMutation, scenario.id, variant.id, channel]);
useEffect(fetchOutput, []);

View File

@@ -1,14 +1,14 @@
import { HStack, Text, useToken } from "@chakra-ui/react";
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types";
import { cellPadding } from "../constants";
import { api } from "~/utils/api";
import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";
export default function VariantStats(props: { variant: PromptVariant }) {
const evalResults =
api.evaluations.results.useQuery({
variantId: props.variant.id,
}).data ?? [];
const { evalResults, overallCost } = api.promptVariants.stats.useQuery({
variantId: props.variant.id,
}).data ?? { evalResults: [] };
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
@@ -18,21 +18,29 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
if (!(evalResults.length > 0)) return null;
if (!(evalResults.length > 0) && !overallCost) return null;
return (
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
{evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
return (
<HStack key={result.id}>
<Text>{result.evaluation.name}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}%
</Text>
</HStack>
);
})}
<HStack justifyContent="space-between" alignItems="center" mx="2">
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
{evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
return (
<HStack key={result.id}>
<Text>{result.evaluation.name}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}%
</Text>
</HStack>
);
})}
</HStack>
{overallCost && (
<HStack spacing={0} align="center" color="gray.500" fontSize="xs" my="2">
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{overallCost.toFixed(3)}</Text>
</HStack>
)}
</HStack>
);
}

View File

@@ -14,15 +14,6 @@ export const evaluationsRouter = createTRPCRouter({
});
}),
results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
return await prisma.evaluationResult.findMany({
where: {
promptVariantId: input.variantId,
},
include: { evaluation: true },
});
}),
create: publicProcedure
.input(
z.object({

View File

@@ -62,6 +62,8 @@ export const modelOutputsRouter = createTRPCRouter({
statusCode: existingResponse.statusCode,
errorMessage: existingResponse.errorMessage,
timeToComplete: existingResponse.timeToComplete,
promptTokens: existingResponse.promptTokens ?? undefined,
completionTokens: existingResponse.completionTokens ?? undefined,
};
} else {
modelResponse = await getCompletion(filledTemplate, input.channel);

View File

@@ -2,6 +2,8 @@ import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { type OpenAIChatConfig } from "~/server/types";
import { getModelName } from "~/server/utils/getModelName";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -14,6 +16,46 @@ export const promptVariantsRouter = createTRPCRouter({
});
}),
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
if (!variant) {
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
const evalResults = await prisma.evaluationResult.findMany({
where: {
promptVariantId: input.variantId,
},
include: { evaluation: true },
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
promptVariantId: input.variantId,
},
_sum: {
promptTokens: true,
completionTokens: true,
},
});
const model = getModelName(variant.config);
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const overallPromptCost = calculateTokenCost(model, promptTokens);
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, overallCost };
}),
create: publicProcedure
.input(
z.object({

View File

@@ -1,8 +1,9 @@
import { isObject } from "lodash";
import { type JSONSerializable, type SupportedModel } from "../types";
import { type Prisma } from "@prisma/client";
export function getModelName(config: JSONSerializable): SupportedModel | null {
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
if (!isObject(config)) return null;
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
return null
return null;
}

View File

@@ -22,18 +22,25 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-3.5-turbo-16k-0613": 0.000004,
};
export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => {
if (model in OpenAIChatModel) {
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
}
return 0;
}
export const calculateTokenCost = (
model: SupportedModel | null,
numTokens: number,
isCompletion = false
) => {
if (!model) return 0;
if (model in OpenAIChatModel) {
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
}
return 0;
};
const calculateOpenAIChatTokenCost = (
model: OpenAIChatModel,
numTokens: number,
isCompletion: boolean
) => {
const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model];
return tokensToDollars * numTokens;
const tokensToDollars = isCompletion
? openAICompletionTokensToDollars[model]
: openAIPromptTokensToDollars[model];
return tokensToDollars * numTokens;
};