Record model and cost when reporting logs (#136)

* Rename prompt and completion tokens to input and output tokens

* Add getUsage function

* Record model and cost when reporting logs

* Remove unused imports

* Move UsageGraph to its own component

* Standardize model response fields

* Fix types
arcticfly committed 2023-08-11 13:56:47 -07:00 (committed by GitHub)
parent f270579283
commit 8d1ee62ff1
24 changed files with 295 additions and 199 deletions

View File

@@ -0,0 +1,66 @@
/*
Warnings:

- You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table.
- You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table.
- You are about to rename the column `startTime` to `requestedAt` on the `LoggedCall` table.
- You are about to rename the column `startTime` to `requestedAt` on the `LoggedCallModelResponse` table.
- You are about to rename the column `endTime` to `receivedAt` on the `LoggedCallModelResponse` table.
- You are about to rename the column `error` to `errorMessage` on the `LoggedCallModelResponse` table.
- You are about to rename the column `respStatus` to `statusCode` on the `LoggedCallModelResponse` table.
- You are about to rename the column `totalCost` to `cost` on the `LoggedCallModelResponse` table.
- You are about to rename the column `inputHash` to `cacheKey` on the `ModelResponse` table.
- You are about to rename the column `output` to `respPayload` on the `ModelResponse` table.

Ensure application logic is updated to match the renamed columns.
*/
-- DropIndex
DROP INDEX "LoggedCall_startTime_idx";
-- DropIndex
DROP INDEX "ModelResponse_inputHash_idx";
-- Rename completionTokens to outputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "completionTokens" TO "outputTokens";
-- Rename promptTokens to inputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "promptTokens" TO "inputTokens";
-- AlterTable
ALTER TABLE "LoggedCall"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "endTime" TO "receivedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "error" TO "errorMessage";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "respStatus" TO "statusCode";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "totalCost" TO "cost";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "inputHash" TO "cacheKey";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "output" TO "respPayload";
-- CreateIndex
CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt");
-- CreateIndex
CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey");

View File

@@ -112,13 +112,13 @@ model ScenarioVariantCell {
model ModelResponse {
id String @id @default(uuid()) @db.Uuid
- inputHash String
+ cacheKey String
requestedAt DateTime?
receivedAt DateTime?
- output Json?
+ respPayload Json?
cost Float?
- promptTokens Int?
- completionTokens Int?
+ inputTokens Int?
+ outputTokens Int?
statusCode Int?
errorMessage String?
retryTime DateTime?
@@ -131,7 +131,7 @@ model ModelResponse {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[]
- @@index([inputHash])
+ @@index([cacheKey])
}
enum EvalType {
@@ -256,7 +256,7 @@ model WorldChampEntrant {
model LoggedCall {
id String @id @default(uuid()) @db.Uuid
- startTime DateTime
+ requestedAt DateTime
// True if this call was served from the cache, false otherwise
cacheHit Boolean
@@ -278,7 +278,7 @@ model LoggedCall {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
- @@index([startTime])
+ @@index([requestedAt])
}
model LoggedCallModelResponse {
@@ -287,14 +287,14 @@ model LoggedCallModelResponse {
reqPayload Json
// The HTTP status returned by the model provider
- respStatus Int?
+ statusCode Int?
respPayload Json?
// Should be null if the request was successful, and some string if the request failed.
- error String?
+ errorMessage String?
- startTime DateTime
- endTime DateTime
+ requestedAt DateTime
+ receivedAt DateTime
// Note: the function to calculate the cacheKey should include the project
// ID so we don't share cached responses between projects, which could be an
@@ -308,7 +308,7 @@ model LoggedCallModelResponse {
outputTokens Int?
finishReason String?
completionId String?
- totalCost Decimal? @db.Decimal(18, 12)
+ cost Decimal? @db.Decimal(18, 12)
// The LoggedCall that created this LoggedCallModelResponse
originalLoggedCallId String @unique @db.Uuid

View File

@@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) {
MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!;
const model = template.reqPayload.model;
// choose random time in the last two weeks, with a bias towards the last few days
- const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
+ const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
// choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5
const delay =
model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4;
- const endTime = new Date(startTime.getTime() + delay);
+ const receivedAt = new Date(requestedAt.getTime() + delay);
loggedCallsToCreate.push({
id: loggedCallId,
cacheHit: false,
- startTime,
+ requestedAt,
projectId: project.id,
- createdAt: startTime,
+ createdAt: requestedAt,
});
const { promptTokenPrice, completionTokenPrice } =
@@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) {
loggedCallModelResponsesToCreate.push({
id: loggedCallModelResponseId,
- startTime,
- endTime,
+ requestedAt,
+ receivedAt,
originalLoggedCallId: loggedCallId,
reqPayload: template.reqPayload,
respPayload: template.respPayload,
- respStatus: template.respStatus,
- error: template.error,
- createdAt: startTime,
+ statusCode: template.respStatus,
+ errorMessage: template.error,
+ createdAt: requestedAt,
cacheKey: hashRequest(project.id, template.reqPayload as JsonValue),
- durationMs: endTime.getTime() - startTime.getTime(),
+ durationMs: receivedAt.getTime() - requestedAt.getTime(),
inputTokens: template.inputTokens,
outputTokens: template.outputTokens,
finishReason: template.finishReason,
- totalCost:
- template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
+ cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
});
loggedCallsToUpdate.push({
where: {

View File

@@ -107,7 +107,7 @@ export default function OutputCell({
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
- const showLogs = !streamedMessage && !mostRecentResponse?.output;
+ const showLogs = !streamedMessage && !mostRecentResponse?.respPayload;
if (showLogs)
return (
@@ -160,13 +160,13 @@ export default function OutputCell({
</CellWrapper>
);
- const normalizedOutput = mostRecentResponse?.output
- ? provider.normalizeOutput(mostRecentResponse?.output)
+ const normalizedOutput = mostRecentResponse?.respPayload
+ ? provider.normalizeOutput(mostRecentResponse?.respPayload)
: streamedMessage
? provider.normalizeOutput(streamedMessage)
: null;
- if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
+ if (mostRecentResponse?.respPayload && normalizedOutput?.type === "json") {
return (
<CellWrapper>
<SyntaxHighlighter

View File

@@ -19,8 +19,8 @@ export const OutputStats = ({
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
: 0;
- const promptTokens = modelResponse.promptTokens;
- const completionTokens = modelResponse.completionTokens;
+ const inputTokens = modelResponse.inputTokens;
+ const outputTokens = modelResponse.outputTokens;
return (
<HStack
@@ -55,8 +55,8 @@ export const OutputStats = ({
</HStack>
{modelResponse.cost && (
<CostTooltip
- promptTokens={promptTokens}
- completionTokens={completionTokens}
+ inputTokens={inputTokens}
+ outputTokens={outputTokens}
cost={modelResponse.cost}
>
<HStack spacing={0}>

View File

@@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
initialData: {
evalResults: [],
overallCost: 0,
- promptTokens: 0,
- completionTokens: 0,
+ inputTokens: 0,
+ outputTokens: 0,
scenarioCount: 0,
outputCount: 0,
awaitingEvals: false,
@@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
</HStack>
{data.overallCost && (
<CostTooltip
- promptTokens={data.promptTokens}
- completionTokens={data.completionTokens}
+ inputTokens={data.inputTokens}
+ outputTokens={data.outputTokens}
cost={data.overallCost}
>
<HStack spacing={0} align="center" color="gray.500">

View File

@@ -90,9 +90,9 @@ function TableRow({
isExpanded: boolean;
onToggle: () => void;
}) {
- const isError = loggedCall.modelResponse?.respStatus !== 200;
- const timeAgo = dayjs(loggedCall.startTime).fromNow();
- const fullTime = dayjs(loggedCall.startTime).toString();
+ const isError = loggedCall.modelResponse?.statusCode !== 200;
+ const timeAgo = dayjs(loggedCall.requestedAt).fromNow();
+ const fullTime = dayjs(loggedCall.requestedAt).toString();
const model = useMemo(
() => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value,
@@ -124,7 +124,7 @@ function TableRow({
<Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td>
<Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td>
<Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric>
- {loggedCall.modelResponse?.respStatus ?? "No response"}
+ {loggedCall.modelResponse?.statusCode ?? "No response"}
</Td>
</Tr>
<Tr>

View File

@@ -0,0 +1,61 @@
import {
ResponsiveContainer,
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
} from "recharts";
import { useMemo } from "react";
import { useSelectedProject } from "~/utils/hooks";
import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api";
export default function UsageGraph() {
const { data: selectedProject } = useSelectedProject();
const stats = api.dashboard.stats.useQuery(
{ projectId: selectedProject?.id ?? "" },
{ enabled: !!selectedProject },
);
const data = useMemo(() => {
return (
stats.data?.periods.map(({ period, numQueries, cost }) => ({
period,
Requests: numQueries,
"Total Spent (USD)": parseFloat(cost.toString()),
})) || []
);
}, [stats.data]);
return (
<ResponsiveContainer width="100%" height={400}>
<LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
<XAxis dataKey="period" tickFormatter={(str: string) => dayjs(str).format("MMM D")} />
<YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
<YAxis
yAxisId="right"
dataKey="Total Spent (USD)"
orientation="right"
unit="$"
stroke="#82ca9d"
/>
<Tooltip />
<Legend />
<CartesianGrid stroke="#f5f5f5" />
<Line dataKey="Requests" stroke="#8884d8" yAxisId="left" dot={false} strokeWidth={2} />
<Line
dataKey="Total Spent (USD)"
stroke="#82ca9d"
yAxisId="right"
dot={false}
strokeWidth={2}
/>
</LineChart>
</ResponsiveContainer>
);
}
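
A minimal sketch of how the extracted component can be dropped into a page (the Card wrapper here is an assumption; the dashboard page below renders it inside an existing CardBody):

import { Card, CardBody } from "@chakra-ui/react";
import UsageGraph from "~/components/dashboard/UsageGraph";

// UsageGraph fetches its own data via api.dashboard.stats,
// so no props are required.
export default function UsageCard() {
  return (
    <Card>
      <CardBody>
        <UsageGraph />
      </CardBody>
    </Card>
  );
}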

View File

@@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from
import { BsCurrencyDollar } from "react-icons/bs";
type CostTooltipProps = {
- promptTokens: number | null;
- completionTokens: number | null;
+ inputTokens: number | null;
+ outputTokens: number | null;
cost: number;
} & TooltipProps;
export const CostTooltip = ({
- promptTokens,
- completionTokens,
+ inputTokens,
+ outputTokens,
cost,
children,
...props
@@ -36,12 +36,12 @@ export const CostTooltip = ({
<HStack>
<VStack w="28" spacing={1}>
<Text>Prompt</Text>
- <Text>{promptTokens ?? 0}</Text>
+ <Text>{inputTokens ?? 0}</Text>
</VStack>
<Divider borderColor="gray.200" h={8} orientation="vertical" />
<VStack w="28" spacing={1}>
<Text whiteSpace="nowrap">Completion</Text>
- <Text>{completionTokens ?? 0}</Text>
+ <Text>{outputTokens ?? 0}</Text>
</VStack>
</HStack>
</VStack>

View File

@@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = {
inputSchema: inputSchema as JSONSchema4,
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ // TODO: add usage logic
+ return null;
+ },
...frontendModelProvider,
};

View File

@@ -4,14 +4,10 @@ import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat";
- import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types";
import { isArray, isString, omit } from "lodash-es";
import { openai } from "~/server/utils/openai";
- import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai";
- import frontendModelProvider from "./frontend";
- import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = (
base: ChatCompletion | null,
@@ -60,9 +56,6 @@ export async function getCompletion(
): Promise<CompletionResponse<ChatCompletion>> {
const start = Date.now();
let finalCompletion: ChatCompletion | null = null;
- let promptTokens: number | undefined = undefined;
- let completionTokens: number | undefined = undefined;
- const modelName = modelProvider.getModel(input) as SupportedModel;
try {
if (onStream) {
@@ -86,16 +79,6 @@ export async function getCompletion(
autoRetry: false,
};
}
- try {
- promptTokens = countOpenAIChatTokens(modelName, input.messages);
- completionTokens = countOpenAIChatTokens(
- modelName,
- finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
- );
- } catch (err) {
- // TODO handle this, library seems like maybe it doesn't work with function calls?
- console.error(err);
- }
} else {
const resp = await openai.chat.completions.create(
{ ...input, stream: false },
@@ -104,25 +87,14 @@ export async function getCompletion(
},
);
finalCompletion = resp;
- promptTokens = resp.usage?.prompt_tokens ?? 0;
- completionTokens = resp.usage?.completion_tokens ?? 0;
}
const timeToComplete = Date.now() - start;
- const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
- let cost = undefined;
- if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
- cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
- }
return {
type: "success",
statusCode: 200,
value: finalCompletion,
timeToComplete,
- promptTokens,
- completionTokens,
- cost,
};
} catch (error: unknown) {
if (error instanceof APIError) {

View File

@@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend";
+ import { countOpenAIChatTokens } from "~/utils/countTokens";
+ import { truthyFilter } from "~/utils/utils";
const supportedModels = [
"gpt-4-0613",
@@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4,
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ if (output.choices.length === 0) return null;
+ const model = modelProvider.getModel(input);
+ if (!model) return null;
+ let inputTokens: number;
+ let outputTokens: number;
+ if (output.usage) {
+ inputTokens = output.usage.prompt_tokens;
+ outputTokens = output.usage.completion_tokens;
+ } else {
+ try {
+ inputTokens = countOpenAIChatTokens(model, input.messages);
+ outputTokens = countOpenAIChatTokens(
+ model,
+ output.choices.map((c) => c.message).filter(truthyFilter),
+ );
+ } catch (err) {
+ inputTokens = 0;
+ outputTokens = 0;
+ // TODO handle this, library seems like maybe it doesn't work with function calls?
+ console.error(err);
+ }
+ }
+ const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model];
+ let cost = undefined;
+ if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) {
+ cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice;
+ }
+ return { inputTokens: inputTokens, outputTokens: outputTokens, cost };
+ },
...frontendModelProvider,
};
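
To illustrate the two code paths above, a sketch of calling getUsage with a hypothetical request/response pair (all literal values are illustrative; the casts stand in for properly typed data):

import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import modelProvider from "~/modelProviders/openai-ChatCompletion";

const input = {
  model: "gpt-3.5-turbo-0613",
  messages: [{ role: "user", content: "Hello!" }],
} as CompletionCreateParams;

const output = {
  id: "chatcmpl-123",
  object: "chat.completion",
  created: 1691785007,
  model: "gpt-3.5-turbo-0613",
  choices: [
    { index: 0, message: { role: "assistant", content: "Hi there!" }, finish_reason: "stop" },
  ],
  usage: { prompt_tokens: 9, completion_tokens: 12, total_tokens: 21 },
} as ChatCompletion;

// Because output.usage is present, token counts come straight from the API
// response; otherwise they would be recomputed with countOpenAIChatTokens.
// cost = 9 * promptTokenPrice + 12 * completionTokenPrice for this model.
const usage = modelProvider.getUsage(input, output);
// => { inputTokens: 9, outputTokens: 12, cost: ... }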

View File

@@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = {
},
canStream: true,
getCompletion,
+ getUsage: (input, output) => {
+ // TODO: add usage logic
+ return null;
+ },
...frontendModelProvider,
};

View File

@@ -43,9 +43,6 @@ export type CompletionResponse<T> =
value: T;
timeToComplete: number;
statusCode: number;
- promptTokens?: number;
- completionTokens?: number;
- cost?: number;
};
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
@@ -56,6 +53,10 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
input: InputSchema,
onStream: ((partialOutput: OutputSchema) => void) | null,
) => Promise<CompletionResponse<OutputSchema>>;
+ getUsage: (
+ input: InputSchema,
+ output: OutputSchema,
+ ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
// This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null;
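
The intended call pattern for this contract, mirroring how queryModel consumes it below (a sketch; provider, prompt, and onStream come from the surrounding task context):

const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
  // Usage is now derived from the raw input/output pair instead of being
  // threaded through CompletionResponse.
  const usage = provider.getUsage(prompt.modelInput, response.value);
  const { inputTokens, outputTokens, cost } = usage ?? {};
}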

View File

@@ -18,26 +18,15 @@ import {
Breadcrumb,
BreadcrumbItem,
} from "@chakra-ui/react";
- import {
- LineChart,
- Line,
- XAxis,
- YAxis,
- CartesianGrid,
- Tooltip,
- Legend,
- ResponsiveContainer,
- } from "recharts";
import { Ban, DollarSign, Hash } from "lucide-react";
import { useMemo } from "react";
import AppShell from "~/components/nav/AppShell";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import { useSelectedProject } from "~/utils/hooks";
- import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api";
import LoggedCallTable from "~/components/dashboard/LoggedCallTable";
+ import UsageGraph from "~/components/dashboard/UsageGraph";
export default function LoggedCalls() {
const { data: selectedProject } = useSelectedProject();
@@ -47,16 +36,6 @@ export default function LoggedCalls() {
{ enabled: !!selectedProject },
);
- const data = useMemo(() => {
- return (
- stats.data?.periods.map(({ period, numQueries, totalCost }) => ({
- period,
- Requests: numQueries,
- "Total Spent (USD)": parseFloat(totalCost.toString()),
- })) || []
- );
- }, [stats.data]);
return (
<AppShell requireAuth>
<PageHeaderContainer>
@@ -83,39 +62,7 @@ export default function LoggedCalls() {
</Heading>
</CardHeader>
<CardBody>
- <ResponsiveContainer width="100%" height={400}>
- <LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
- <XAxis
- dataKey="period"
- tickFormatter={(str: string) => dayjs(str).format("MMM D")}
- />
- <YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
- <YAxis
- yAxisId="right"
- dataKey="Total Spent (USD)"
- orientation="right"
- unit="$"
- stroke="#82ca9d"
- />
- <Tooltip />
- <Legend />
- <CartesianGrid stroke="#f5f5f5" />
- <Line
- dataKey="Requests"
- stroke="#8884d8"
- yAxisId="left"
- dot={false}
- strokeWidth={2}
- />
- <Line
- dataKey="Total Spent (USD)"
- stroke="#82ca9d"
- yAxisId="right"
- dot={false}
- strokeWidth={2}
- />
- </LineChart>
- </ResponsiveContainer>
+ <UsageGraph />
</CardBody>
</Card>
<VStack spacing="4" width="300px" align="stretch">
@@ -127,7 +74,7 @@ export default function LoggedCalls() {
<Icon as={DollarSign} boxSize={4} color="gray.500" />
</HStack>
<StatNumber>
- ${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)}
+ ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)}
</StatNumber>
</Stat>
</CardBody>

View File

@@ -38,7 +38,10 @@ export default function Settings() {
id: selectedProject.id,
updates: { name },
});
- await Promise.all([utils.projects.get.invalidate({ id: selectedProject.id })]);
+ await Promise.all([
+ utils.projects.get.invalidate({ id: selectedProject.id }),
+ utils.projects.list.invalidate(),
+ ]);
}
}, [updateMutation, selectedProject]);

View File

@@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
- sql<Date>`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"),
+ sql<Date>`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
sql<number>`count("LoggedCall"."id")::int`.as("numQueries"),
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
])
.groupBy("period")
.orderBy("period")
@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
backfilledPeriods.unshift({
period: dayjs(dayToMatch).toDate(),
numQueries: 0,
- totalCost: 0,
+ cost: 0,
});
}
dayToMatch = dayToMatch.subtract(1, "day");
@@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
fn.count("LoggedCall.id").as("numQueries"),
])
.executeTakeFirst();
@@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId",
)
- .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"])
- .where("respStatus", ">", 200)
+ .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
+ .where("statusCode", ">", 200)
.groupBy("code")
.orderBy("count", "desc")
.execute();
@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
// https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
const loggedCalls = await prisma.loggedCall.findMany({
orderBy: { startTime: "desc" },
orderBy: { requestedAt: "desc" },
include: { tags: true, modelResponse: true },
take: 20,
});

View File

@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
- output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
+ respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({

View File

@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject";
+ import modelProvider from "~/modelProviders/openai-ChatCompletion";
+ import {
+ type ChatCompletion,
+ type CompletionCreateParams,
+ } from "openai/resources/chat/completions";
const reqValidator = z.object({
model: z.string(),
@@ -16,11 +21,6 @@ const reqValidator = z.object({
const respValidator = z.object({
id: z.string(),
model: z.string(),
- usage: z.object({
- total_tokens: z.number(),
- prompt_tokens: z.number(),
- completion_tokens: z.number(),
- }),
choices: z.array(
z.object({
finish_reason: z.string(),
@@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({
originalLoggedCall: true,
},
orderBy: {
startTime: "desc",
requestedAt: "desc",
},
});
@@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({
await prisma.loggedCall.create({
data: {
projectId: key.projectId,
- startTime: new Date(input.startTime),
+ requestedAt: new Date(input.startTime),
cacheHit: true,
modelResponseId: existingResponse.id,
},
@@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({
const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4();
- const usage = respPayload.success ? respPayload.data.usage : undefined;
+ let usage;
+ if (reqPayload.success && respPayload.success) {
+ usage = modelProvider.getUsage(
+ input.reqPayload as CompletionCreateParams,
+ input.respPayload as ChatCompletion,
+ );
+ }
await prisma.$transaction([
prisma.loggedCall.create({
data: {
id: newLoggedCallId,
projectId: key.projectId,
- startTime: new Date(input.startTime),
+ requestedAt: new Date(input.startTime),
cacheHit: false,
},
}),
@@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({
data: {
id: newModelResponseId,
originalLoggedCallId: newLoggedCallId,
- startTime: new Date(input.startTime),
- endTime: new Date(input.endTime),
+ requestedAt: new Date(input.startTime),
+ receivedAt: new Date(input.endTime),
reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue,
- respStatus: input.respStatus,
- error: input.error,
+ statusCode: input.respStatus,
+ errorMessage: input.error,
durationMs: input.endTime - input.startTime,
- ...(respPayload.success
- ? {
- cacheKey: requestHash,
- inputTokens: usage ? usage.prompt_tokens : undefined,
- outputTokens: usage ? usage.completion_tokens : undefined,
- }
- : null),
+ cacheKey: respPayload.success ? requestHash : null,
+ inputTokens: usage?.inputTokens,
+ outputTokens: usage?.outputTokens,
+ cost: usage?.cost,
},
}),
// Avoid foreign key constraint error by updating the logged call after the model response is created
@@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({
}),
]);
- if (input.tags) {
- const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
- loggedCallId: newLoggedCallId,
- // sanitize tags
- name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
- value,
- }));
+ const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
+ loggedCallId: newLoggedCallId,
+ // sanitize tags
+ name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
+ value,
+ }));
- if (reqPayload.success) {
- tagsToCreate.push({
- loggedCallId: newLoggedCallId,
- name: "$model",
- value: reqPayload.data.model,
- });
- }
- await prisma.loggedCallTag.createMany({
- data: tagsToCreate,
+ if (reqPayload.success) {
+ tagsToCreate.push({
+ loggedCallId: newLoggedCallId,
+ name: "$model",
+ value: reqPayload.data.model,
+ });
+ }
+ await prisma.loggedCallTag.createMany({
+ data: tagsToCreate,
+ });
}),
});
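
For reference, the tag-name sanitization above maps arbitrary keys onto a safe character set; a quick illustration of its behavior (the example tag names are hypothetical):

// "$model.version" -> "_model_version"; "user-id" -> "user_id"
const sanitize = (name: string) => name.replaceAll(/[^a-zA-Z0-9_]/g, "_");
console.log(sanitize("$model.version"), sanitize("user-id"));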

View File

@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
where: {
modelResponse: {
outdated: false,
- output: { not: Prisma.AnyNull },
+ respPayload: { not: Prisma.AnyNull },
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
modelResponses: {
some: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
},
@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallTokens = await prisma.modelResponse.aggregate({
where: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
},
_sum: {
cost: true,
- promptTokens: true,
- completionTokens: true,
+ inputTokens: true,
+ outputTokens: true,
},
});
- const promptTokens = overallTokens._sum?.promptTokens ?? 0;
- const completionTokens = overallTokens._sum?.completionTokens ?? 0;
+ const inputTokens = overallTokens._sum?.inputTokens ?? 0;
+ const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length,
@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
return {
evalResults,
- promptTokens,
- completionTokens,
+ inputTokens,
+ outputTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
outputCount,

View File

@@ -99,26 +99,27 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}
: null;
- const inputHash = hashObject(prompt as JsonValue);
+ const cacheKey = hashObject(prompt as JsonValue);
let modelResponse = await prisma.modelResponse.create({
data: {
- inputHash,
+ cacheKey,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
+ const usage = provider.getUsage(prompt.modelInput, response.value);
modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
- output: response.value as Prisma.InputJsonObject,
+ respPayload: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode,
receivedAt: new Date(),
- promptTokens: response.promptTokens,
- completionTokens: response.completionTokens,
- cost: response.cost,
+ inputTokens: usage?.inputTokens,
+ outputTokens: usage?.outputTokens,
+ cost: usage?.cost,
},
});

View File

@@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelResponse.findMany({
where: {
outdated: false,
- output: {
+ respPayload: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {

View File

@@ -57,7 +57,7 @@ export const generateNewCell = async (
return;
}
- const inputHash = hashObject(parsedConstructFn);
+ const cacheKey = hashObject(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({
data: {
@@ -73,8 +73,8 @@ export const generateNewCell = async (
const matchingModelResponse = await prisma.modelResponse.findFirst({
where: {
- inputHash,
- output: {
+ cacheKey,
+ respPayload: {
not: Prisma.AnyNull,
},
},
@@ -92,7 +92,7 @@ export const generateNewCell = async (
data: {
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id,
- output: matchingModelResponse.output as Prisma.InputJsonValue,
+ respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue,
},
});

View File

@@ -71,7 +71,7 @@ export const runOneEval = async (
provider: SupportedProvider,
): Promise<{ result: number; details?: string }> => {
const modelProvider = modelProviders[provider];
- const message = modelProvider.normalizeOutput(modelResponse.output);
+ const message = modelProvider.normalizeOutput(modelResponse.respPayload);
if (!message) return { result: 0 };