Add total token cost to variant stats (#13)
* Add total token cost to variant stats
* Copy over token counts for new variants
* Update invalidate call
This commit is contained in:
@@ -105,7 +105,7 @@ export default function EditEvaluations() {
|
|||||||
const [onDelete] = useHandledAsyncCallback(async (id: string) => {
|
const [onDelete] = useHandledAsyncCallback(async (id: string) => {
|
||||||
await deleteMutation.mutateAsync({ id });
|
await deleteMutation.mutateAsync({ id });
|
||||||
await utils.evaluations.list.invalidate();
|
await utils.evaluations.list.invalidate();
|
||||||
await utils.evaluations.results.invalidate();
|
await utils.promptVariants.stats.invalidate();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
|
const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
|
||||||
@@ -124,7 +124,7 @@ export default function EditEvaluations() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
await utils.evaluations.list.invalidate();
|
await utils.evaluations.list.invalidate();
|
||||||
await utils.evaluations.results.invalidate();
|
await utils.promptVariants.stats.invalidate();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onCancel = useCallback(() => {
|
const onCancel = useCallback(() => {
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ export default function OutputCell({
|
|||||||
channel,
|
channel,
|
||||||
});
|
});
|
||||||
setOutput(output);
|
setOutput(output);
|
||||||
await utils.evaluations.results.invalidate();
|
await utils.promptVariants.stats.invalidate();
|
||||||
}, [outputMutation, scenario.id, variant.id, channel]);
|
}, [outputMutation, scenario.id, variant.id, channel]);
|
||||||
|
|
||||||
useEffect(fetchOutput, []);
|
useEffect(fetchOutput, []);
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
import { HStack, Text, useToken } from "@chakra-ui/react";
|
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
|
||||||
import { type PromptVariant } from "./types";
|
import { type PromptVariant } from "./types";
|
||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import chroma from "chroma-js";
|
import chroma from "chroma-js";
|
||||||
|
import { BsCurrencyDollar } from "react-icons/bs";
|
||||||
|
|
||||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||||
const evalResults =
|
const { evalResults, overallCost } = api.promptVariants.stats.useQuery({
|
||||||
api.evaluations.results.useQuery({
|
|
||||||
variantId: props.variant.id,
|
variantId: props.variant.id,
|
||||||
}).data ?? [];
|
}).data ?? { evalResults: [] };
|
||||||
|
|
||||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||||
"green.500",
|
"green.500",
|
||||||
@@ -18,9 +18,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
|
|
||||||
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
||||||
|
|
||||||
if (!(evalResults.length > 0)) return null;
|
if (!(evalResults.length > 0) && !overallCost) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<HStack justifyContent="space-between" alignItems="center" mx="2">
|
||||||
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
|
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
|
||||||
{evalResults.map((result) => {
|
{evalResults.map((result) => {
|
||||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||||
@@ -34,5 +35,12 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
|
{overallCost && (
|
||||||
|
<HStack spacing={0} align="center" color="gray.500" fontSize="xs" my="2">
|
||||||
|
<Icon as={BsCurrencyDollar} />
|
||||||
|
<Text mr={1}>{overallCost.toFixed(3)}</Text>
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,15 +14,6 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
|
|
||||||
return await prisma.evaluationResult.findMany({
|
|
||||||
where: {
|
|
||||||
promptVariantId: input.variantId,
|
|
||||||
},
|
|
||||||
include: { evaluation: true },
|
|
||||||
});
|
|
||||||
}),
|
|
||||||
|
|
||||||
create: publicProcedure
|
create: publicProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
statusCode: existingResponse.statusCode,
|
statusCode: existingResponse.statusCode,
|
||||||
errorMessage: existingResponse.errorMessage,
|
errorMessage: existingResponse.errorMessage,
|
||||||
timeToComplete: existingResponse.timeToComplete,
|
timeToComplete: existingResponse.timeToComplete,
|
||||||
|
promptTokens: existingResponse.promptTokens ?? undefined,
|
||||||
|
completionTokens: existingResponse.completionTokens ?? undefined,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
modelResponse = await getCompletion(filledTemplate, input.channel);
|
modelResponse = await getCompletion(filledTemplate, input.channel);
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import { z } from "zod";
|
|||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { type OpenAIChatConfig } from "~/server/types";
|
import { type OpenAIChatConfig } from "~/server/types";
|
||||||
|
import { getModelName } from "~/server/utils/getModelName";
|
||||||
|
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||||
|
|
||||||
export const promptVariantsRouter = createTRPCRouter({
|
export const promptVariantsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||||
@@ -14,6 +16,46 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
|
||||||
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
|
where: {
|
||||||
|
id: input.variantId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!variant) {
|
||||||
|
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const evalResults = await prisma.evaluationResult.findMany({
|
||||||
|
where: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
},
|
||||||
|
include: { evaluation: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
const overallTokens = await prisma.modelOutput.aggregate({
|
||||||
|
where: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
},
|
||||||
|
_sum: {
|
||||||
|
promptTokens: true,
|
||||||
|
completionTokens: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const model = getModelName(variant.config);
|
||||||
|
|
||||||
|
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||||
|
const overallPromptCost = calculateTokenCost(model, promptTokens);
|
||||||
|
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||||
|
const overallCompletionCost = calculateTokenCost(model, completionTokens, true);
|
||||||
|
|
||||||
|
const overallCost = overallPromptCost + overallCompletionCost;
|
||||||
|
|
||||||
|
return { evalResults, overallCost };
|
||||||
|
}),
|
||||||
|
|
||||||
create: publicProcedure
|
create: publicProcedure
|
||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash";
|
||||||
import { type JSONSerializable, type SupportedModel } from "../types";
|
import { type JSONSerializable, type SupportedModel } from "../types";
|
||||||
|
import { type Prisma } from "@prisma/client";
|
||||||
|
|
||||||
export function getModelName(config: JSONSerializable): SupportedModel | null {
|
export function getModelName(config: JSONSerializable | Prisma.JsonValue): SupportedModel | null {
|
||||||
if (!isObject(config)) return null;
|
if (!isObject(config)) return null;
|
||||||
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
|
if ("model" in config && typeof config.model === "string") return config.model as SupportedModel;
|
||||||
return null
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,18 +22,25 @@ const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
|
|||||||
"gpt-3.5-turbo-16k-0613": 0.000004,
|
"gpt-3.5-turbo-16k-0613": 0.000004,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const calculateTokenCost = (model: SupportedModel, numTokens: number, isCompletion = false) => {
|
export const calculateTokenCost = (
|
||||||
|
model: SupportedModel | null,
|
||||||
|
numTokens: number,
|
||||||
|
isCompletion = false
|
||||||
|
) => {
|
||||||
|
if (!model) return 0;
|
||||||
if (model in OpenAIChatModel) {
|
if (model in OpenAIChatModel) {
|
||||||
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
|
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
};
|
||||||
|
|
||||||
const calculateOpenAIChatTokenCost = (
|
const calculateOpenAIChatTokenCost = (
|
||||||
model: OpenAIChatModel,
|
model: OpenAIChatModel,
|
||||||
numTokens: number,
|
numTokens: number,
|
||||||
isCompletion: boolean
|
isCompletion: boolean
|
||||||
) => {
|
) => {
|
||||||
const tokensToDollars = isCompletion ? openAICompletionTokensToDollars[model] : openAIPromptTokensToDollars[model];
|
const tokensToDollars = isCompletion
|
||||||
|
? openAICompletionTokensToDollars[model]
|
||||||
|
: openAIPromptTokensToDollars[model];
|
||||||
return tokensToDollars * numTokens;
|
return tokensToDollars * numTokens;
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user