Record model and cost when reporting logs (#136)

* Rename prompt and completion tokens to input and output tokens

* Add getUsage function

* Record model and cost when reporting log

* Remove unused imports

* Move UsageGraph to its own component

* Standardize model response fields

* Fix types
This commit is contained in:
arcticfly
2023-08-11 13:56:47 -07:00
committed by GitHub
parent f270579283
commit 8d1ee62ff1
24 changed files with 295 additions and 199 deletions

View File

@@ -0,0 +1,66 @@
/*
  Migration: standardize column names across ModelResponse, LoggedCall, and
  LoggedCallModelResponse. Application code must be updated in lockstep.

  Renames performed:
  - `ModelResponse`:            `completionTokens` -> `outputTokens`,
                                `promptTokens`     -> `inputTokens`,
                                `inputHash`        -> `cacheKey`,
                                `output`           -> `respPayload`
  - `LoggedCall`:               `startTime`        -> `requestedAt`
  - `LoggedCallModelResponse`:  `startTime`        -> `requestedAt`,
                                `endTime`          -> `receivedAt`,
                                `error`            -> `errorMessage`,
                                `respStatus`       -> `statusCode`,
                                `totalCost`        -> `cost`
*/
-- Drop the two indexes that reference columns being renamed; they are
-- recreated under the new column names at the end of this migration.
DROP INDEX "LoggedCall_startTime_idx";
-- DropIndex
DROP INDEX "ModelResponse_inputHash_idx";
-- ModelResponse: completionTokens -> outputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "completionTokens" TO "outputTokens";
-- ModelResponse: promptTokens -> inputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "promptTokens" TO "inputTokens";
-- LoggedCall: startTime -> requestedAt
ALTER TABLE "LoggedCall"
RENAME COLUMN "startTime" TO "requestedAt";
-- LoggedCallModelResponse: startTime -> requestedAt
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "startTime" TO "requestedAt";
-- LoggedCallModelResponse: endTime -> receivedAt
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "endTime" TO "receivedAt";
-- LoggedCallModelResponse: error -> errorMessage
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "error" TO "errorMessage";
-- LoggedCallModelResponse: respStatus -> statusCode
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "respStatus" TO "statusCode";
-- LoggedCallModelResponse: totalCost -> cost
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "totalCost" TO "cost";
-- ModelResponse: inputHash -> cacheKey
ALTER TABLE "ModelResponse"
RENAME COLUMN "inputHash" TO "cacheKey";
-- ModelResponse: output -> respPayload
ALTER TABLE "ModelResponse"
RENAME COLUMN "output" TO "respPayload";
-- Recreate the dropped LoggedCall index under the renamed column.
CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt");
-- Recreate the dropped ModelResponse index under the renamed column.
CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey");

View File

@@ -112,13 +112,13 @@ model ScenarioVariantCell {
model ModelResponse { model ModelResponse {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
inputHash String cacheKey String
requestedAt DateTime? requestedAt DateTime?
receivedAt DateTime? receivedAt DateTime?
output Json? respPayload Json?
cost Float? cost Float?
promptTokens Int? inputTokens Int?
completionTokens Int? outputTokens Int?
statusCode Int? statusCode Int?
errorMessage String? errorMessage String?
retryTime DateTime? retryTime DateTime?
@@ -131,7 +131,7 @@ model ModelResponse {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[] outputEvaluations OutputEvaluation[]
@@index([inputHash]) @@index([cacheKey])
} }
enum EvalType { enum EvalType {
@@ -256,7 +256,7 @@ model WorldChampEntrant {
model LoggedCall { model LoggedCall {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
startTime DateTime requestedAt DateTime
// True if this call was served from the cache, false otherwise // True if this call was served from the cache, false otherwise
cacheHit Boolean cacheHit Boolean
@@ -278,7 +278,7 @@ model LoggedCall {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@index([startTime]) @@index([requestedAt])
} }
model LoggedCallModelResponse { model LoggedCallModelResponse {
@@ -287,14 +287,14 @@ model LoggedCallModelResponse {
reqPayload Json reqPayload Json
// The HTTP status returned by the model provider // The HTTP status returned by the model provider
respStatus Int? statusCode Int?
respPayload Json? respPayload Json?
// Should be null if the request was successful, and some string if the request failed. // Should be null if the request was successful, and some string if the request failed.
error String? errorMessage String?
startTime DateTime requestedAt DateTime
endTime DateTime receivedAt DateTime
// Note: the function to calculate the cacheKey should include the project // Note: the function to calculate the cacheKey should include the project
// ID so we don't share cached responses between projects, which could be an // ID so we don't share cached responses between projects, which could be an
@@ -308,7 +308,7 @@ model LoggedCallModelResponse {
outputTokens Int? outputTokens Int?
finishReason String? finishReason String?
completionId String? completionId String?
totalCost Decimal? @db.Decimal(18, 12) cost Decimal? @db.Decimal(18, 12)
// The LoggedCall that created this LoggedCallModelResponse // The LoggedCall that created this LoggedCallModelResponse
originalLoggedCallId String @unique @db.Uuid originalLoggedCallId String @unique @db.Uuid

View File

@@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) {
MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!; MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!;
const model = template.reqPayload.model; const model = template.reqPayload.model;
// choose random time in the last two weeks, with a bias towards the last few days // choose random time in the last two weeks, with a bias towards the last few days
const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14); const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
// choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5 // choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5
const delay = const delay =
model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4; model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4;
const endTime = new Date(startTime.getTime() + delay); const receivedAt = new Date(requestedAt.getTime() + delay);
loggedCallsToCreate.push({ loggedCallsToCreate.push({
id: loggedCallId, id: loggedCallId,
cacheHit: false, cacheHit: false,
startTime, requestedAt,
projectId: project.id, projectId: project.id,
createdAt: startTime, createdAt: requestedAt,
}); });
const { promptTokenPrice, completionTokenPrice } = const { promptTokenPrice, completionTokenPrice } =
@@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) {
loggedCallModelResponsesToCreate.push({ loggedCallModelResponsesToCreate.push({
id: loggedCallModelResponseId, id: loggedCallModelResponseId,
startTime, requestedAt,
endTime, receivedAt,
originalLoggedCallId: loggedCallId, originalLoggedCallId: loggedCallId,
reqPayload: template.reqPayload, reqPayload: template.reqPayload,
respPayload: template.respPayload, respPayload: template.respPayload,
respStatus: template.respStatus, statusCode: template.respStatus,
error: template.error, errorMessage: template.error,
createdAt: startTime, createdAt: requestedAt,
cacheKey: hashRequest(project.id, template.reqPayload as JsonValue), cacheKey: hashRequest(project.id, template.reqPayload as JsonValue),
durationMs: endTime.getTime() - startTime.getTime(), durationMs: receivedAt.getTime() - requestedAt.getTime(),
inputTokens: template.inputTokens, inputTokens: template.inputTokens,
outputTokens: template.outputTokens, outputTokens: template.outputTokens,
finishReason: template.finishReason, finishReason: template.finishReason,
totalCost: cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
}); });
loggedCallsToUpdate.push({ loggedCallsToUpdate.push({
where: { where: {

View File

@@ -107,7 +107,7 @@ export default function OutputCell({
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>; if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
const showLogs = !streamedMessage && !mostRecentResponse?.output; const showLogs = !streamedMessage && !mostRecentResponse?.respPayload;
if (showLogs) if (showLogs)
return ( return (
@@ -160,13 +160,13 @@ export default function OutputCell({
</CellWrapper> </CellWrapper>
); );
const normalizedOutput = mostRecentResponse?.output const normalizedOutput = mostRecentResponse?.respPayload
? provider.normalizeOutput(mostRecentResponse?.output) ? provider.normalizeOutput(mostRecentResponse?.respPayload)
: streamedMessage : streamedMessage
? provider.normalizeOutput(streamedMessage) ? provider.normalizeOutput(streamedMessage)
: null; : null;
if (mostRecentResponse?.output && normalizedOutput?.type === "json") { if (mostRecentResponse?.respPayload && normalizedOutput?.type === "json") {
return ( return (
<CellWrapper> <CellWrapper>
<SyntaxHighlighter <SyntaxHighlighter

View File

@@ -19,8 +19,8 @@ export const OutputStats = ({
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime() ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
: 0; : 0;
const promptTokens = modelResponse.promptTokens; const inputTokens = modelResponse.inputTokens;
const completionTokens = modelResponse.completionTokens; const outputTokens = modelResponse.outputTokens;
return ( return (
<HStack <HStack
@@ -55,8 +55,8 @@ export const OutputStats = ({
</HStack> </HStack>
{modelResponse.cost && ( {modelResponse.cost && (
<CostTooltip <CostTooltip
promptTokens={promptTokens} inputTokens={inputTokens}
completionTokens={completionTokens} outputTokens={outputTokens}
cost={modelResponse.cost} cost={modelResponse.cost}
> >
<HStack spacing={0}> <HStack spacing={0}>

View File

@@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
initialData: { initialData: {
evalResults: [], evalResults: [],
overallCost: 0, overallCost: 0,
promptTokens: 0, inputTokens: 0,
completionTokens: 0, outputTokens: 0,
scenarioCount: 0, scenarioCount: 0,
outputCount: 0, outputCount: 0,
awaitingEvals: false, awaitingEvals: false,
@@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
</HStack> </HStack>
{data.overallCost && ( {data.overallCost && (
<CostTooltip <CostTooltip
promptTokens={data.promptTokens} inputTokens={data.inputTokens}
completionTokens={data.completionTokens} outputTokens={data.outputTokens}
cost={data.overallCost} cost={data.overallCost}
> >
<HStack spacing={0} align="center" color="gray.500"> <HStack spacing={0} align="center" color="gray.500">

View File

@@ -90,9 +90,9 @@ function TableRow({
isExpanded: boolean; isExpanded: boolean;
onToggle: () => void; onToggle: () => void;
}) { }) {
const isError = loggedCall.modelResponse?.respStatus !== 200; const isError = loggedCall.modelResponse?.statusCode !== 200;
const timeAgo = dayjs(loggedCall.startTime).fromNow(); const timeAgo = dayjs(loggedCall.requestedAt).fromNow();
const fullTime = dayjs(loggedCall.startTime).toString(); const fullTime = dayjs(loggedCall.requestedAt).toString();
const model = useMemo( const model = useMemo(
() => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value, () => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value,
@@ -124,7 +124,7 @@ function TableRow({
<Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td> <Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td>
<Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td> <Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td>
<Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric> <Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric>
{loggedCall.modelResponse?.respStatus ?? "No response"} {loggedCall.modelResponse?.statusCode ?? "No response"}
</Td> </Td>
</Tr> </Tr>
<Tr> <Tr>

View File

@@ -0,0 +1,61 @@
import {
ResponsiveContainer,
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
} from "recharts";
import { useMemo } from "react";
import { useSelectedProject } from "~/utils/hooks";
import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api";
/**
 * Dashboard line chart of per-day usage for the currently selected project.
 *
 * Plots two series from the `dashboard.stats` tRPC query:
 * - "Requests": number of logged calls per period (left y-axis)
 * - "Total Spent (USD)": summed cost per period (right y-axis)
 */
export default function UsageGraph() {
  const { data: selectedProject } = useSelectedProject();

  // `enabled` gates the request until a project is selected, so the
  // empty-string projectId placeholder is never actually sent.
  const stats = api.dashboard.stats.useQuery(
    { projectId: selectedProject?.id ?? "" },
    { enabled: !!selectedProject },
  );

  // Reshape per-period stats into recharts' row format. Use nullish `??`
  // rather than `||` so only a missing payload falls back to the empty array.
  const data = useMemo(
    () =>
      stats.data?.periods.map(({ period, numQueries, cost }) => ({
        period,
        Requests: numQueries,
        // `cost` may be a decimal-like object — TODO confirm; round-tripping
        // through its string form converts it to a plain number for the chart.
        "Total Spent (USD)": parseFloat(cost.toString()),
      })) ?? [],
    [stats.data],
  );

  return (
    <ResponsiveContainer width="100%" height={400}>
      <LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
        <XAxis dataKey="period" tickFormatter={(str: string) => dayjs(str).format("MMM D")} />
        <YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
        <YAxis
          yAxisId="right"
          dataKey="Total Spent (USD)"
          orientation="right"
          unit="$"
          stroke="#82ca9d"
        />
        <Tooltip />
        <Legend />
        <CartesianGrid stroke="#f5f5f5" />
        <Line dataKey="Requests" stroke="#8884d8" yAxisId="left" dot={false} strokeWidth={2} />
        <Line
          dataKey="Total Spent (USD)"
          stroke="#82ca9d"
          yAxisId="right"
          dot={false}
          strokeWidth={2}
        />
      </LineChart>
    </ResponsiveContainer>
  );
}

View File

@@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from
import { BsCurrencyDollar } from "react-icons/bs"; import { BsCurrencyDollar } from "react-icons/bs";
type CostTooltipProps = { type CostTooltipProps = {
promptTokens: number | null; inputTokens: number | null;
completionTokens: number | null; outputTokens: number | null;
cost: number; cost: number;
} & TooltipProps; } & TooltipProps;
export const CostTooltip = ({ export const CostTooltip = ({
promptTokens, inputTokens,
completionTokens, outputTokens,
cost, cost,
children, children,
...props ...props
@@ -36,12 +36,12 @@ export const CostTooltip = ({
<HStack> <HStack>
<VStack w="28" spacing={1}> <VStack w="28" spacing={1}>
<Text>Prompt</Text> <Text>Prompt</Text>
<Text>{promptTokens ?? 0}</Text> <Text>{inputTokens ?? 0}</Text>
</VStack> </VStack>
<Divider borderColor="gray.200" h={8} orientation="vertical" /> <Divider borderColor="gray.200" h={8} orientation="vertical" />
<VStack w="28" spacing={1}> <VStack w="28" spacing={1}>
<Text whiteSpace="nowrap">Completion</Text> <Text whiteSpace="nowrap">Completion</Text>
<Text>{completionTokens ?? 0}</Text> <Text>{outputTokens ?? 0}</Text>
</VStack> </VStack>
</HStack> </HStack>
</VStack> </VStack>

View File

@@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
// TODO: add usage logic
return null;
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -4,14 +4,10 @@ import {
type ChatCompletion, type ChatCompletion,
type CompletionCreateParams, type CompletionCreateParams,
} from "openai/resources/chat"; } from "openai/resources/chat";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types"; import { type CompletionResponse } from "../types";
import { isArray, isString, omit } from "lodash-es"; import { isArray, isString, omit } from "lodash-es";
import { openai } from "~/server/utils/openai"; import { openai } from "~/server/utils/openai";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai"; import { APIError } from "openai";
import frontendModelProvider from "./frontend";
import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = ( const mergeStreamedChunks = (
base: ChatCompletion | null, base: ChatCompletion | null,
@@ -60,9 +56,6 @@ export async function getCompletion(
): Promise<CompletionResponse<ChatCompletion>> { ): Promise<CompletionResponse<ChatCompletion>> {
const start = Date.now(); const start = Date.now();
let finalCompletion: ChatCompletion | null = null; let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
const modelName = modelProvider.getModel(input) as SupportedModel;
try { try {
if (onStream) { if (onStream) {
@@ -86,16 +79,6 @@ export async function getCompletion(
autoRetry: false, autoRetry: false,
}; };
} }
try {
promptTokens = countOpenAIChatTokens(modelName, input.messages);
completionTokens = countOpenAIChatTokens(
modelName,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
// TODO handle this, library seems like maybe it doesn't work with function calls?
console.error(err);
}
} else { } else {
const resp = await openai.chat.completions.create( const resp = await openai.chat.completions.create(
{ ...input, stream: false }, { ...input, stream: false },
@@ -104,25 +87,14 @@ export async function getCompletion(
}, },
); );
finalCompletion = resp; finalCompletion = resp;
promptTokens = resp.usage?.prompt_tokens ?? 0;
completionTokens = resp.usage?.completion_tokens ?? 0;
} }
const timeToComplete = Date.now() - start; const timeToComplete = Date.now() - start;
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
let cost = undefined;
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
}
return { return {
type: "success", type: "success",
statusCode: 200, statusCode: 200,
value: finalCompletion, value: finalCompletion,
timeToComplete, timeToComplete,
promptTokens,
completionTokens,
cost,
}; };
} catch (error: unknown) { } catch (error: unknown) {
if (error instanceof APIError) { if (error instanceof APIError) {

View File

@@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend"; import frontendModelProvider from "./frontend";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { truthyFilter } from "~/utils/utils";
const supportedModels = [ const supportedModels = [
"gpt-4-0613", "gpt-4-0613",
@@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
if (output.choices.length === 0) return null;
const model = modelProvider.getModel(input);
if (!model) return null;
let inputTokens: number;
let outputTokens: number;
if (output.usage) {
inputTokens = output.usage.prompt_tokens;
outputTokens = output.usage.completion_tokens;
} else {
try {
inputTokens = countOpenAIChatTokens(model, input.messages);
outputTokens = countOpenAIChatTokens(
model,
output.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
inputTokens = 0;
outputTokens = 0;
// TODO handle this, library seems like maybe it doesn't work with function calls?
console.error(err);
}
}
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model];
let cost = undefined;
if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) {
cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice;
}
return { inputTokens: inputTokens, outputTokens: outputTokens, cost };
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = {
}, },
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
// TODO: add usage logic
return null;
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -43,9 +43,6 @@ export type CompletionResponse<T> =
value: T; value: T;
timeToComplete: number; timeToComplete: number;
statusCode: number; statusCode: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
}; };
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = { export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
@@ -56,6 +53,10 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
input: InputSchema, input: InputSchema,
onStream: ((partialOutput: OutputSchema) => void) | null, onStream: ((partialOutput: OutputSchema) => void) | null,
) => Promise<CompletionResponse<OutputSchema>>; ) => Promise<CompletionResponse<OutputSchema>>;
getUsage: (
input: InputSchema,
output: OutputSchema,
) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
// This is just a convenience for type inference, don't use it at runtime // This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null; _outputSchema?: OutputSchema | null;

View File

@@ -18,26 +18,15 @@ import {
Breadcrumb, Breadcrumb,
BreadcrumbItem, BreadcrumbItem,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
ResponsiveContainer,
} from "recharts";
import { Ban, DollarSign, Hash } from "lucide-react"; import { Ban, DollarSign, Hash } from "lucide-react";
import { useMemo } from "react";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import { useSelectedProject } from "~/utils/hooks"; import { useSelectedProject } from "~/utils/hooks";
import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import LoggedCallTable from "~/components/dashboard/LoggedCallTable"; import LoggedCallTable from "~/components/dashboard/LoggedCallTable";
import UsageGraph from "~/components/dashboard/UsageGraph";
export default function LoggedCalls() { export default function LoggedCalls() {
const { data: selectedProject } = useSelectedProject(); const { data: selectedProject } = useSelectedProject();
@@ -47,16 +36,6 @@ export default function LoggedCalls() {
{ enabled: !!selectedProject }, { enabled: !!selectedProject },
); );
const data = useMemo(() => {
return (
stats.data?.periods.map(({ period, numQueries, totalCost }) => ({
period,
Requests: numQueries,
"Total Spent (USD)": parseFloat(totalCost.toString()),
})) || []
);
}, [stats.data]);
return ( return (
<AppShell requireAuth> <AppShell requireAuth>
<PageHeaderContainer> <PageHeaderContainer>
@@ -83,39 +62,7 @@ export default function LoggedCalls() {
</Heading> </Heading>
</CardHeader> </CardHeader>
<CardBody> <CardBody>
<ResponsiveContainer width="100%" height={400}> <UsageGraph />
<LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
<XAxis
dataKey="period"
tickFormatter={(str: string) => dayjs(str).format("MMM D")}
/>
<YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
<YAxis
yAxisId="right"
dataKey="Total Spent (USD)"
orientation="right"
unit="$"
stroke="#82ca9d"
/>
<Tooltip />
<Legend />
<CartesianGrid stroke="#f5f5f5" />
<Line
dataKey="Requests"
stroke="#8884d8"
yAxisId="left"
dot={false}
strokeWidth={2}
/>
<Line
dataKey="Total Spent (USD)"
stroke="#82ca9d"
yAxisId="right"
dot={false}
strokeWidth={2}
/>
</LineChart>
</ResponsiveContainer>
</CardBody> </CardBody>
</Card> </Card>
<VStack spacing="4" width="300px" align="stretch"> <VStack spacing="4" width="300px" align="stretch">
@@ -127,7 +74,7 @@ export default function LoggedCalls() {
<Icon as={DollarSign} boxSize={4} color="gray.500" /> <Icon as={DollarSign} boxSize={4} color="gray.500" />
</HStack> </HStack>
<StatNumber> <StatNumber>
${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)} ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)}
</StatNumber> </StatNumber>
</Stat> </Stat>
</CardBody> </CardBody>

View File

@@ -38,7 +38,10 @@ export default function Settings() {
id: selectedProject.id, id: selectedProject.id,
updates: { name }, updates: { name },
}); });
await Promise.all([utils.projects.get.invalidate({ id: selectedProject.id })]); await Promise.all([
utils.projects.get.invalidate({ id: selectedProject.id }),
utils.projects.list.invalidate(),
]);
} }
}, [updateMutation, selectedProject]); }, [updateMutation, selectedProject]);

View File

@@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({
) )
.where("projectId", "=", input.projectId) .where("projectId", "=", input.projectId)
.select(({ fn }) => [ .select(({ fn }) => [
sql<Date>`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"), sql<Date>`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
sql<number>`count("LoggedCall"."id")::int`.as("numQueries"), sql<number>`count("LoggedCall"."id")::int`.as("numQueries"),
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"), fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
]) ])
.groupBy("period") .groupBy("period")
.orderBy("period") .orderBy("period")
@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
backfilledPeriods.unshift({ backfilledPeriods.unshift({
period: dayjs(dayToMatch).toDate(), period: dayjs(dayToMatch).toDate(),
numQueries: 0, numQueries: 0,
totalCost: 0, cost: 0,
}); });
} }
dayToMatch = dayToMatch.subtract(1, "day"); dayToMatch = dayToMatch.subtract(1, "day");
@@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({
) )
.where("projectId", "=", input.projectId) .where("projectId", "=", input.projectId)
.select(({ fn }) => [ .select(({ fn }) => [
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"), fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
fn.count("LoggedCall.id").as("numQueries"), fn.count("LoggedCall.id").as("numQueries"),
]) ])
.executeTakeFirst(); .executeTakeFirst();
@@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id", "LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId", "LoggedCallModelResponse.originalLoggedCallId",
) )
.select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"]) .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
.where("respStatus", ">", 200) .where("statusCode", ">", 200)
.groupBy("code") .groupBy("code")
.orderBy("count", "desc") .orderBy("count", "desc")
.execute(); .execute();
@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
// https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758 // https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => { loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
const loggedCalls = await prisma.loggedCall.findMany({ const loggedCalls = await prisma.loggedCall.findMany({
orderBy: { startTime: "desc" }, orderBy: { requestedAt: "desc" },
include: { tags: true, modelResponse: true }, include: { tags: true, modelResponse: true },
take: 20, take: 20,
}); });

View File

@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
...modelResponseData, ...modelResponseData,
id: newModelResponseId, id: newModelResponseId,
scenarioVariantCellId: newCellId, scenarioVariantCellId: newCellId,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
}); });
for (const evaluation of outputEvaluations) { for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({ outputEvaluationsToCreate.push({

View File

@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject"; import { hashRequest } from "~/server/utils/hashObject";
import modelProvider from "~/modelProviders/openai-ChatCompletion";
import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat/completions";
const reqValidator = z.object({ const reqValidator = z.object({
model: z.string(), model: z.string(),
@@ -16,11 +21,6 @@ const reqValidator = z.object({
const respValidator = z.object({ const respValidator = z.object({
id: z.string(), id: z.string(),
model: z.string(), model: z.string(),
usage: z.object({
total_tokens: z.number(),
prompt_tokens: z.number(),
completion_tokens: z.number(),
}),
choices: z.array( choices: z.array(
z.object({ z.object({
finish_reason: z.string(), finish_reason: z.string(),
@@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({
originalLoggedCall: true, originalLoggedCall: true,
}, },
orderBy: { orderBy: {
startTime: "desc", requestedAt: "desc",
}, },
}); });
@@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({
await prisma.loggedCall.create({ await prisma.loggedCall.create({
data: { data: {
projectId: key.projectId, projectId: key.projectId,
startTime: new Date(input.startTime), requestedAt: new Date(input.startTime),
cacheHit: true, cacheHit: true,
modelResponseId: existingResponse.id, modelResponseId: existingResponse.id,
}, },
@@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({
const newLoggedCallId = uuidv4(); const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4(); const newModelResponseId = uuidv4();
const usage = respPayload.success ? respPayload.data.usage : undefined; let usage;
if (reqPayload.success && respPayload.success) {
usage = modelProvider.getUsage(
input.reqPayload as CompletionCreateParams,
input.respPayload as ChatCompletion,
);
}
await prisma.$transaction([ await prisma.$transaction([
prisma.loggedCall.create({ prisma.loggedCall.create({
data: { data: {
id: newLoggedCallId, id: newLoggedCallId,
projectId: key.projectId, projectId: key.projectId,
startTime: new Date(input.startTime), requestedAt: new Date(input.startTime),
cacheHit: false, cacheHit: false,
}, },
}), }),
@@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({
data: { data: {
id: newModelResponseId, id: newModelResponseId,
originalLoggedCallId: newLoggedCallId, originalLoggedCallId: newLoggedCallId,
startTime: new Date(input.startTime), requestedAt: new Date(input.startTime),
endTime: new Date(input.endTime), receivedAt: new Date(input.endTime),
reqPayload: input.reqPayload as Prisma.InputJsonValue, reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue, respPayload: input.respPayload as Prisma.InputJsonValue,
respStatus: input.respStatus, statusCode: input.respStatus,
error: input.error, errorMessage: input.error,
durationMs: input.endTime - input.startTime, durationMs: input.endTime - input.startTime,
...(respPayload.success cacheKey: respPayload.success ? requestHash : null,
? { inputTokens: usage?.inputTokens,
cacheKey: requestHash, outputTokens: usage?.outputTokens,
inputTokens: usage ? usage.prompt_tokens : undefined, cost: usage?.cost,
outputTokens: usage ? usage.completion_tokens : undefined,
}
: null),
}, },
}), }),
// Avoid foreign key constraint error by updating the logged call after the model response is created // Avoid foreign key constraint error by updating the logged call after the model response is created
@@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({
}), }),
]); ]);
if (input.tags) { const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({ loggedCallId: newLoggedCallId,
loggedCallId: newLoggedCallId, // sanitize tags
// sanitize tags name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"), value,
value, }));
}));
if (reqPayload.success) { if (reqPayload.success) {
tagsToCreate.push({ tagsToCreate.push({
loggedCallId: newLoggedCallId, loggedCallId: newLoggedCallId,
name: "$model", name: "$model",
value: reqPayload.data.model, value: reqPayload.data.model,
});
}
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
}); });
} }
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
});
}), }),
}); });

View File

@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
where: { where: {
modelResponse: { modelResponse: {
outdated: false, outdated: false,
output: { not: Prisma.AnyNull }, respPayload: { not: Prisma.AnyNull },
scenarioVariantCell: { scenarioVariantCell: {
promptVariant: { promptVariant: {
id: input.variantId, id: input.variantId,
@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
modelResponses: { modelResponses: {
some: { some: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
}, },
@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallTokens = await prisma.modelResponse.aggregate({ const overallTokens = await prisma.modelResponse.aggregate({
where: { where: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
scenarioVariantCell: { scenarioVariantCell: {
@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
}, },
_sum: { _sum: {
cost: true, cost: true,
promptTokens: true, inputTokens: true,
completionTokens: true, outputTokens: true,
}, },
}); });
const promptTokens = overallTokens._sum?.promptTokens ?? 0; const inputTokens = overallTokens._sum?.inputTokens ?? 0;
const completionTokens = overallTokens._sum?.completionTokens ?? 0; const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingEvals = !!evalResults.find( const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length, (result) => result.totalCount < scenarioCount * evals.length,
@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
return { return {
evalResults, evalResults,
promptTokens, inputTokens,
completionTokens, outputTokens,
overallCost: overallTokens._sum?.cost ?? 0, overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount, scenarioCount,
outputCount, outputCount,

View File

@@ -99,26 +99,27 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
} }
: null; : null;
const inputHash = hashObject(prompt as JsonValue); const cacheKey = hashObject(prompt as JsonValue);
let modelResponse = await prisma.modelResponse.create({ let modelResponse = await prisma.modelResponse.create({
data: { data: {
inputHash, cacheKey,
scenarioVariantCellId: cellId, scenarioVariantCellId: cellId,
requestedAt: new Date(), requestedAt: new Date(),
}, },
}); });
const response = await provider.getCompletion(prompt.modelInput, onStream); const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") { if (response.type === "success") {
const usage = provider.getUsage(prompt.modelInput, response.value);
modelResponse = await prisma.modelResponse.update({ modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id }, where: { id: modelResponse.id },
data: { data: {
output: response.value as Prisma.InputJsonObject, respPayload: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode, statusCode: response.statusCode,
receivedAt: new Date(), receivedAt: new Date(),
promptTokens: response.promptTokens, inputTokens: usage?.inputTokens,
completionTokens: response.completionTokens, outputTokens: usage?.outputTokens,
cost: response.cost, cost: usage?.cost,
}, },
}); });

View File

@@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelResponse.findMany({ const outputs = await prisma.modelResponse.findMany({
where: { where: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
scenarioVariantCell: { scenarioVariantCell: {

View File

@@ -57,7 +57,7 @@ export const generateNewCell = async (
return; return;
} }
const inputHash = hashObject(parsedConstructFn); const cacheKey = hashObject(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({ cell = await prisma.scenarioVariantCell.create({
data: { data: {
@@ -73,8 +73,8 @@ export const generateNewCell = async (
const matchingModelResponse = await prisma.modelResponse.findFirst({ const matchingModelResponse = await prisma.modelResponse.findFirst({
where: { where: {
inputHash, cacheKey,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
}, },
@@ -92,7 +92,7 @@ export const generateNewCell = async (
data: { data: {
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]), ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id, scenarioVariantCellId: cell.id,
output: matchingModelResponse.output as Prisma.InputJsonValue, respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue,
}, },
}); });

View File

@@ -71,7 +71,7 @@ export const runOneEval = async (
provider: SupportedProvider, provider: SupportedProvider,
): Promise<{ result: number; details?: string }> => { ): Promise<{ result: number; details?: string }> => {
const modelProvider = modelProviders[provider]; const modelProvider = modelProviders[provider];
const message = modelProvider.normalizeOutput(modelResponse.output); const message = modelProvider.normalizeOutput(modelResponse.respPayload);
if (!message) return { result: 0 }; if (!message) return { result: 0 };