diff --git a/app/prisma/migrations/20230811200114_rename_model_response_fields/migration.sql b/app/prisma/migrations/20230811200114_rename_model_response_fields/migration.sql new file mode 100644 index 0000000..7d638ad --- /dev/null +++ b/app/prisma/migrations/20230811200114_rename_model_response_fields/migration.sql @@ -0,0 +1,66 @@ +/* + Warnings: + + - You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table. + - You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table. + - You are about to rename the column `startTime` on the `LoggedCall` table to `requestedAt`. Ensure compatibility with application logic. + - You are about to rename the column `startTime` on the `LoggedCallModelResponse` table to `requestedAt`. Ensure compatibility with application logic. + - You are about to rename the column `endTime` on the `LoggedCallModelResponse` table to `receivedAt`. Ensure compatibility with application logic. + - You are about to rename the column `error` on the `LoggedCallModelResponse` table to `errorMessage`. Ensure compatibility with application logic. + - You are about to rename the column `respStatus` on the `LoggedCallModelResponse` table to `statusCode`. Ensure compatibility with application logic. + - You are about to rename the column `totalCost` on the `LoggedCallModelResponse` table to `cost`. Ensure compatibility with application logic. + - You are about to rename the column `inputHash` on the `ModelResponse` table to `cacheKey`. Ensure compatibility with application logic. + - You are about to rename the column `output` on the `ModelResponse` table to `respPayload`. Ensure compatibility with application logic. + +*/ +-- DropIndex +DROP INDEX "LoggedCall_startTime_idx"; + +-- DropIndex +DROP INDEX "ModelResponse_inputHash_idx"; + +-- Rename completionTokens to outputTokens +ALTER TABLE "ModelResponse" +RENAME COLUMN "completionTokens" TO "outputTokens"; + +-- Rename promptTokens to inputTokens +ALTER TABLE "ModelResponse" +RENAME COLUMN "promptTokens" TO "inputTokens"; + +-- AlterTable +ALTER TABLE "LoggedCall" +RENAME COLUMN "startTime" TO "requestedAt"; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" +RENAME COLUMN "startTime" TO "requestedAt"; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" +RENAME COLUMN "endTime" TO "receivedAt"; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" +RENAME COLUMN "error" TO "errorMessage"; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" +RENAME COLUMN "respStatus" TO "statusCode"; + +-- AlterTable +ALTER TABLE "LoggedCallModelResponse" +RENAME COLUMN "totalCost" TO "cost"; + +-- AlterTable +ALTER TABLE "ModelResponse" +RENAME COLUMN "inputHash" TO "cacheKey"; + +-- AlterTable +ALTER TABLE "ModelResponse" +RENAME COLUMN "output" TO "respPayload"; + +-- CreateIndex +CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt"); + +-- CreateIndex +CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey"); diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma index 1600669..59319f7 100644 --- a/app/prisma/schema.prisma +++ b/app/prisma/schema.prisma @@ -112,13 +112,13 @@ model ScenarioVariantCell { model ModelResponse { id String @id @default(uuid()) @db.Uuid - inputHash String + cacheKey String requestedAt DateTime? receivedAt DateTime? - output Json? + respPayload Json? cost Float? - promptTokens Int? - completionTokens Int? + inputTokens Int? + outputTokens Int? statusCode Int? errorMessage String? retryTime DateTime? @@ -131,7 +131,7 @@ model ModelResponse { scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) outputEvaluations OutputEvaluation[] - @@index([inputHash]) + @@index([cacheKey]) } enum EvalType { @@ -256,7 +256,7 @@ model WorldChampEntrant { model LoggedCall { id String @id @default(uuid()) @db.Uuid - startTime DateTime + requestedAt DateTime // True if this call was served from the cache, false otherwise cacheHit Boolean @@ -278,7 +278,7 @@ model LoggedCall { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt - @@index([startTime]) + @@index([requestedAt]) } model LoggedCallModelResponse { @@ -287,14 +287,14 @@ model LoggedCallModelResponse { reqPayload Json // The HTTP status returned by the model provider - respStatus Int? + statusCode Int? respPayload Json? // Should be null if the request was successful, and some string if the request failed. - error String? + errorMessage String? - startTime DateTime - endTime DateTime + requestedAt DateTime + receivedAt DateTime // Note: the function to calculate the cacheKey should include the project // ID so we don't share cached responses between projects, which could be an @@ -308,7 +308,7 @@ model LoggedCallModelResponse { outputTokens Int? finishReason String? completionId String? - totalCost Decimal? @db.Decimal(18, 12) + cost Decimal? @db.Decimal(18, 12) // The LoggedCall that created this LoggedCallModelResponse originalLoggedCallId String @unique @db.Uuid diff --git a/app/prisma/seedDashboard.ts b/app/prisma/seedDashboard.ts index 8e55a2d..a12ad6a 100644 --- a/app/prisma/seedDashboard.ts +++ b/app/prisma/seedDashboard.ts @@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) { MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!; const model = template.reqPayload.model; // choose random time in the last two weeks, with a bias towards the last few days - const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14); + const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14); // choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5 const delay = model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4; - const endTime = new Date(startTime.getTime() + delay); + const receivedAt = new Date(requestedAt.getTime() + delay); loggedCallsToCreate.push({ id: loggedCallId, cacheHit: false, - startTime, + requestedAt, projectId: project.id, - createdAt: startTime, + createdAt: requestedAt, }); const { promptTokenPrice, completionTokenPrice } = @@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) { loggedCallModelResponsesToCreate.push({ id: loggedCallModelResponseId, - startTime, - endTime, + requestedAt, + receivedAt, originalLoggedCallId: loggedCallId, reqPayload: template.reqPayload, respPayload: template.respPayload, - respStatus: template.respStatus, - error: template.error, - createdAt: startTime, + statusCode: template.respStatus, + errorMessage: template.error, + createdAt: requestedAt, cacheKey: hashRequest(project.id, template.reqPayload as JsonValue), - durationMs: endTime.getTime() - startTime.getTime(), + durationMs: receivedAt.getTime() - requestedAt.getTime(), inputTokens: template.inputTokens, outputTokens: template.outputTokens, finishReason: template.finishReason, - totalCost: - template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice, + cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice, }); loggedCallsToUpdate.push({ where: { diff --git a/app/src/components/OutputsTable/OutputCell/OutputCell.tsx b/app/src/components/OutputsTable/OutputCell/OutputCell.tsx index 4eafbb7..d679b4f 100644 --- a/app/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/app/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -107,7 +107,7 @@ export default function OutputCell({ if (disabledReason) return {disabledReason}; - const showLogs = !streamedMessage && !mostRecentResponse?.output; + const showLogs = !streamedMessage && !mostRecentResponse?.respPayload; if (showLogs) return ( @@ -160,13 +160,13 @@ export default function OutputCell({ ); - const normalizedOutput = mostRecentResponse?.output - ? provider.normalizeOutput(mostRecentResponse?.output) + const normalizedOutput = mostRecentResponse?.respPayload + ? provider.normalizeOutput(mostRecentResponse?.respPayload) : streamedMessage ? provider.normalizeOutput(streamedMessage) : null; - if (mostRecentResponse?.output && normalizedOutput?.type === "json") { + if (mostRecentResponse?.respPayload && normalizedOutput?.type === "json") { return ( {modelResponse.cost && ( diff --git a/app/src/components/OutputsTable/VariantStats.tsx b/app/src/components/OutputsTable/VariantStats.tsx index 40bae1d..8c81fdb 100644 --- a/app/src/components/OutputsTable/VariantStats.tsx +++ b/app/src/components/OutputsTable/VariantStats.tsx @@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) { initialData: { evalResults: [], overallCost: 0, - promptTokens: 0, - completionTokens: 0, + inputTokens: 0, + outputTokens: 0, scenarioCount: 0, outputCount: 0, awaitingEvals: false, @@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) { {data.overallCost && ( diff --git a/app/src/components/dashboard/LoggedCallTable.tsx b/app/src/components/dashboard/LoggedCallTable.tsx index 70aa7a2..10ed790 100644 --- a/app/src/components/dashboard/LoggedCallTable.tsx +++ b/app/src/components/dashboard/LoggedCallTable.tsx @@ -90,9 +90,9 @@ function TableRow({ isExpanded: boolean; onToggle: () => void; }) { - const isError = loggedCall.modelResponse?.respStatus !== 200; - const timeAgo = dayjs(loggedCall.startTime).fromNow(); - const fullTime = dayjs(loggedCall.startTime).toString(); + const isError = loggedCall.modelResponse?.statusCode !== 200; + const timeAgo = dayjs(loggedCall.requestedAt).fromNow(); + const fullTime = dayjs(loggedCall.requestedAt).toString(); const model = useMemo( () => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value, @@ -124,7 +124,7 @@ function TableRow({ {loggedCall.modelResponse?.inputTokens} {loggedCall.modelResponse?.outputTokens} - {loggedCall.modelResponse?.respStatus ?? "No response"} + {loggedCall.modelResponse?.statusCode ?? "No response"} diff --git a/app/src/components/dashboard/UsageGraph.tsx b/app/src/components/dashboard/UsageGraph.tsx new file mode 100644 index 0000000..dc3faa4 --- /dev/null +++ b/app/src/components/dashboard/UsageGraph.tsx @@ -0,0 +1,61 @@ +import { + ResponsiveContainer, + LineChart, + Line, + XAxis, + YAxis, + CartesianGrid, + Tooltip, + Legend, +} from "recharts"; +import { useMemo } from "react"; + +import { useSelectedProject } from "~/utils/hooks"; +import dayjs from "~/utils/dayjs"; +import { api } from "~/utils/api"; + +export default function UsageGraph() { + const { data: selectedProject } = useSelectedProject(); + + const stats = api.dashboard.stats.useQuery( + { projectId: selectedProject?.id ?? "" }, + { enabled: !!selectedProject }, + ); + + const data = useMemo(() => { + return ( + stats.data?.periods.map(({ period, numQueries, cost }) => ({ + period, + Requests: numQueries, + "Total Spent (USD)": parseFloat(cost.toString()), + })) || [] + ); + }, [stats.data]); + + return ( + + + dayjs(str).format("MMM D")} /> + + + + + + + + + + ); +} diff --git a/app/src/components/tooltip/CostTooltip.tsx b/app/src/components/tooltip/CostTooltip.tsx index 68cf3ea..0e2cd17 100644 --- a/app/src/components/tooltip/CostTooltip.tsx +++ b/app/src/components/tooltip/CostTooltip.tsx @@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from import { BsCurrencyDollar } from "react-icons/bs"; type CostTooltipProps = { - promptTokens: number | null; - completionTokens: number | null; + inputTokens: number | null; + outputTokens: number | null; cost: number; } & TooltipProps; export const CostTooltip = ({ - promptTokens, - completionTokens, + inputTokens, + outputTokens, cost, children, ...props @@ -36,12 +36,12 @@ export const CostTooltip = ({ Prompt - {promptTokens ?? 0} + {inputTokens ?? 0} Completion - {completionTokens ?? 0} + {outputTokens ?? 0} diff --git a/app/src/modelProviders/anthropic-completion/index.ts b/app/src/modelProviders/anthropic-completion/index.ts index 3b2d670..314d2ab 100644 --- a/app/src/modelProviders/anthropic-completion/index.ts +++ b/app/src/modelProviders/anthropic-completion/index.ts @@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = { inputSchema: inputSchema as JSONSchema4, canStream: true, getCompletion, + getUsage: (input, output) => { + // TODO: add usage logic + return null; + }, ...frontendModelProvider, }; diff --git a/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts b/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts index 462f2fa..e6f5123 100644 --- a/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts +++ b/app/src/modelProviders/openai-ChatCompletion/getCompletion.ts @@ -4,14 +4,10 @@ import { type ChatCompletion, type CompletionCreateParams, } from "openai/resources/chat"; -import { countOpenAIChatTokens } from "~/utils/countTokens"; import { type CompletionResponse } from "../types"; import { isArray, isString, omit } from "lodash-es"; import { openai } from "~/server/utils/openai"; -import { truthyFilter } from "~/utils/utils"; import { APIError } from "openai"; -import frontendModelProvider from "./frontend"; -import modelProvider, { type SupportedModel } from "."; const mergeStreamedChunks = ( base: ChatCompletion | null, @@ -60,9 +56,6 @@ export async function getCompletion( ): Promise> { const start = Date.now(); let finalCompletion: ChatCompletion | null = null; - let promptTokens: number | undefined = undefined; - let completionTokens: number | undefined = undefined; - const modelName = modelProvider.getModel(input) as SupportedModel; try { if (onStream) { @@ -86,16 +79,6 @@ export async function getCompletion( autoRetry: false, }; } - try { - promptTokens = countOpenAIChatTokens(modelName, input.messages); - completionTokens = countOpenAIChatTokens( - modelName, - finalCompletion.choices.map((c) => c.message).filter(truthyFilter), - ); - } catch (err) { - // TODO handle this, library seems like maybe it doesn't work with function calls? - console.error(err); - } } else { const resp = await openai.chat.completions.create( { ...input, stream: false }, @@ -104,25 +87,14 @@ export async function getCompletion( }, ); finalCompletion = resp; - promptTokens = resp.usage?.prompt_tokens ?? 0; - completionTokens = resp.usage?.completion_tokens ?? 0; } const timeToComplete = Date.now() - start; - const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName]; - let cost = undefined; - if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) { - cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice; - } - return { type: "success", statusCode: 200, value: finalCompletion, timeToComplete, - promptTokens, - completionTokens, - cost, }; } catch (error: unknown) { if (error instanceof APIError) { diff --git a/app/src/modelProviders/openai-ChatCompletion/index.ts b/app/src/modelProviders/openai-ChatCompletion/index.ts index 2b4e90c..b5dc00f 100644 --- a/app/src/modelProviders/openai-ChatCompletion/index.ts +++ b/app/src/modelProviders/openai-ChatCompletion/index.ts @@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { getCompletion } from "./getCompletion"; import frontendModelProvider from "./frontend"; +import { countOpenAIChatTokens } from "~/utils/countTokens"; +import { truthyFilter } from "~/utils/utils"; const supportedModels = [ "gpt-4-0613", @@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = { inputSchema: inputSchema as JSONSchema4, canStream: true, getCompletion, + getUsage: (input, output) => { + if (output.choices.length === 0) return null; + + const model = modelProvider.getModel(input); + if (!model) return null; + + let inputTokens: number; + let outputTokens: number; + + if (output.usage) { + inputTokens = output.usage.prompt_tokens; + outputTokens = output.usage.completion_tokens; + } else { + try { + inputTokens = countOpenAIChatTokens(model, input.messages); + outputTokens = countOpenAIChatTokens( + model, + output.choices.map((c) => c.message).filter(truthyFilter), + ); + } catch (err) { + inputTokens = 0; + outputTokens = 0; + // TODO handle this, library seems like maybe it doesn't work with function calls? + console.error(err); + } + } + + const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model]; + let cost = undefined; + if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) { + cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice; + } + + return { inputTokens: inputTokens, outputTokens: outputTokens, cost }; + }, ...frontendModelProvider, }; diff --git a/app/src/modelProviders/replicate-llama2/index.ts b/app/src/modelProviders/replicate-llama2/index.ts index b2397ce..9e68eca 100644 --- a/app/src/modelProviders/replicate-llama2/index.ts +++ b/app/src/modelProviders/replicate-llama2/index.ts @@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = { }, canStream: true, getCompletion, + getUsage: (input, output) => { + // TODO: add usage logic + return null; + }, ...frontendModelProvider, }; diff --git a/app/src/modelProviders/types.ts b/app/src/modelProviders/types.ts index 5e5bf26..6b5e09e 100644 --- a/app/src/modelProviders/types.ts +++ b/app/src/modelProviders/types.ts @@ -43,9 +43,6 @@ export type CompletionResponse = value: T; timeToComplete: number; statusCode: number; - promptTokens?: number; - completionTokens?: number; - cost?: number; }; export type ModelProvider = { @@ -56,6 +53,10 @@ export type ModelProvider void) | null, ) => Promise>; + getUsage: ( + input: InputSchema, + output: OutputSchema, + ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null; // This is just a convenience for type inference, don't use it at runtime _outputSchema?: OutputSchema | null; diff --git a/app/src/pages/logged-calls/index.tsx b/app/src/pages/logged-calls/index.tsx index 8acb8ea..f19f90b 100644 --- a/app/src/pages/logged-calls/index.tsx +++ b/app/src/pages/logged-calls/index.tsx @@ -18,26 +18,15 @@ import { Breadcrumb, BreadcrumbItem, } from "@chakra-ui/react"; -import { - LineChart, - Line, - XAxis, - YAxis, - CartesianGrid, - Tooltip, - Legend, - ResponsiveContainer, -} from "recharts"; import { Ban, DollarSign, Hash } from "lucide-react"; -import { useMemo } from "react"; import AppShell from "~/components/nav/AppShell"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; import { useSelectedProject } from "~/utils/hooks"; -import dayjs from "~/utils/dayjs"; import { api } from "~/utils/api"; import LoggedCallTable from "~/components/dashboard/LoggedCallTable"; +import UsageGraph from "~/components/dashboard/UsageGraph"; export default function LoggedCalls() { const { data: selectedProject } = useSelectedProject(); @@ -47,16 +36,6 @@ export default function LoggedCalls() { { enabled: !!selectedProject }, ); - const data = useMemo(() => { - return ( - stats.data?.periods.map(({ period, numQueries, totalCost }) => ({ - period, - Requests: numQueries, - "Total Spent (USD)": parseFloat(totalCost.toString()), - })) || [] - ); - }, [stats.data]); - return ( @@ -83,39 +62,7 @@ export default function LoggedCalls() { - - - dayjs(str).format("MMM D")} - /> - - - - - - - - - + @@ -127,7 +74,7 @@ export default function LoggedCalls() { - ${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)} + ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)} diff --git a/app/src/pages/project/settings/index.tsx b/app/src/pages/project/settings/index.tsx index 0896f32..7fbaf4d 100644 --- a/app/src/pages/project/settings/index.tsx +++ b/app/src/pages/project/settings/index.tsx @@ -38,7 +38,10 @@ export default function Settings() { id: selectedProject.id, updates: { name }, }); - await Promise.all([utils.projects.get.invalidate({ id: selectedProject.id })]); + await Promise.all([ + utils.projects.get.invalidate({ id: selectedProject.id }), + utils.projects.list.invalidate(), + ]); } }, [updateMutation, selectedProject]); diff --git a/app/src/server/api/routers/dashboard.router.ts b/app/src/server/api/routers/dashboard.router.ts index 6c9be03..8508e27 100644 --- a/app/src/server/api/routers/dashboard.router.ts +++ b/app/src/server/api/routers/dashboard.router.ts @@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({ ) .where("projectId", "=", input.projectId) .select(({ fn }) => [ - sql`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"), + sql`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"), sql`count("LoggedCall"."id")::int`.as("numQueries"), - fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql`0`)).as("totalCost"), + fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql`0`)).as("cost"), ]) .groupBy("period") .orderBy("period") @@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({ backfilledPeriods.unshift({ period: dayjs(dayToMatch).toDate(), numQueries: 0, - totalCost: 0, + cost: 0, }); } dayToMatch = dayToMatch.subtract(1, "day"); @@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({ ) .where("projectId", "=", input.projectId) .select(({ fn }) => [ - fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql`0`)).as("totalCost"), + fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql`0`)).as("cost"), fn.count("LoggedCall.id").as("numQueries"), ]) .executeTakeFirst(); @@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({ "LoggedCall.id", "LoggedCallModelResponse.originalLoggedCallId", ) - .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"]) - .where("respStatus", ">", 200) + .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"]) + .where("statusCode", ">", 200) .groupBy("code") .orderBy("count", "desc") .execute(); @@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({ // https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758 loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => { const loggedCalls = await prisma.loggedCall.findMany({ - orderBy: { startTime: "desc" }, + orderBy: { requestedAt: "desc" }, include: { tags: true, modelResponse: true }, take: 20, }); diff --git a/app/src/server/api/routers/experiments.router.ts b/app/src/server/api/routers/experiments.router.ts index 17dd991..f707a9b 100644 --- a/app/src/server/api/routers/experiments.router.ts +++ b/app/src/server/api/routers/experiments.router.ts @@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({ ...modelResponseData, id: newModelResponseId, scenarioVariantCellId: newCellId, - output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, + respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined, }); for (const evaluation of outputEvaluations) { outputEvaluationsToCreate.push({ diff --git a/app/src/server/api/routers/externalApi.router.ts b/app/src/server/api/routers/externalApi.router.ts index 5eff9dd..aecbbbc 100644 --- a/app/src/server/api/routers/externalApi.router.ts +++ b/app/src/server/api/routers/externalApi.router.ts @@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { hashRequest } from "~/server/utils/hashObject"; +import modelProvider from "~/modelProviders/openai-ChatCompletion"; +import { + type ChatCompletion, + type CompletionCreateParams, +} from "openai/resources/chat/completions"; const reqValidator = z.object({ model: z.string(), @@ -16,11 +21,6 @@ const reqValidator = z.object({ const respValidator = z.object({ id: z.string(), model: z.string(), - usage: z.object({ - total_tokens: z.number(), - prompt_tokens: z.number(), - completion_tokens: z.number(), - }), choices: z.array( z.object({ finish_reason: z.string(), @@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({ originalLoggedCall: true, }, orderBy: { - startTime: "desc", + requestedAt: "desc", }, }); @@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({ await prisma.loggedCall.create({ data: { projectId: key.projectId, - startTime: new Date(input.startTime), + requestedAt: new Date(input.startTime), cacheHit: true, modelResponseId: existingResponse.id, }, @@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({ const newLoggedCallId = uuidv4(); const newModelResponseId = uuidv4(); - const usage = respPayload.success ? respPayload.data.usage : undefined; + let usage; + if (reqPayload.success && respPayload.success) { + usage = modelProvider.getUsage( + input.reqPayload as CompletionCreateParams, + input.respPayload as ChatCompletion, + ); + } await prisma.$transaction([ prisma.loggedCall.create({ data: { id: newLoggedCallId, projectId: key.projectId, - startTime: new Date(input.startTime), + requestedAt: new Date(input.startTime), cacheHit: false, }, }), @@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({ data: { id: newModelResponseId, originalLoggedCallId: newLoggedCallId, - startTime: new Date(input.startTime), - endTime: new Date(input.endTime), + requestedAt: new Date(input.startTime), + receivedAt: new Date(input.endTime), reqPayload: input.reqPayload as Prisma.InputJsonValue, respPayload: input.respPayload as Prisma.InputJsonValue, - respStatus: input.respStatus, - error: input.error, + statusCode: input.respStatus, + errorMessage: input.error, durationMs: input.endTime - input.startTime, - ...(respPayload.success - ? { - cacheKey: requestHash, - inputTokens: usage ? usage.prompt_tokens : undefined, - outputTokens: usage ? usage.completion_tokens : undefined, - } - : null), + cacheKey: respPayload.success ? requestHash : null, + inputTokens: usage?.inputTokens, + outputTokens: usage?.outputTokens, + cost: usage?.cost, }, }), // Avoid foreign key constraint error by updating the logged call after the model response is created @@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({ }), ]); - if (input.tags) { - const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({ - loggedCallId: newLoggedCallId, - // sanitize tags - name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"), - value, - })); + const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({ + loggedCallId: newLoggedCallId, + // sanitize tags + name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"), + value, + })); - if (reqPayload.success) { - tagsToCreate.push({ - loggedCallId: newLoggedCallId, - name: "$model", - value: reqPayload.data.model, - }); - } - await prisma.loggedCallTag.createMany({ - data: tagsToCreate, + if (reqPayload.success) { + tagsToCreate.push({ + loggedCallId: newLoggedCallId, + name: "$model", + value: reqPayload.data.model, }); } + await prisma.loggedCallTag.createMany({ + data: tagsToCreate, + }); }), }); diff --git a/app/src/server/api/routers/promptVariants.router.ts b/app/src/server/api/routers/promptVariants.router.ts index f19bbd0..7a01da1 100644 --- a/app/src/server/api/routers/promptVariants.router.ts +++ b/app/src/server/api/routers/promptVariants.router.ts @@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({ where: { modelResponse: { outdated: false, - output: { not: Prisma.AnyNull }, + respPayload: { not: Prisma.AnyNull }, scenarioVariantCell: { promptVariant: { id: input.variantId, @@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({ modelResponses: { some: { outdated: false, - output: { + respPayload: { not: Prisma.AnyNull, }, }, @@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({ const overallTokens = await prisma.modelResponse.aggregate({ where: { outdated: false, - output: { + respPayload: { not: Prisma.AnyNull, }, scenarioVariantCell: { @@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({ }, _sum: { cost: true, - promptTokens: true, - completionTokens: true, + inputTokens: true, + outputTokens: true, }, }); - const promptTokens = overallTokens._sum?.promptTokens ?? 0; - const completionTokens = overallTokens._sum?.completionTokens ?? 0; + const inputTokens = overallTokens._sum?.inputTokens ?? 0; + const outputTokens = overallTokens._sum?.outputTokens ?? 0; const awaitingEvals = !!evalResults.find( (result) => result.totalCount < scenarioCount * evals.length, @@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({ return { evalResults, - promptTokens, - completionTokens, + inputTokens, + outputTokens, overallCost: overallTokens._sum?.cost ?? 0, scenarioCount, outputCount, diff --git a/app/src/server/tasks/queryModel.task.ts b/app/src/server/tasks/queryModel.task.ts index d7a5dc8..4580f2c 100644 --- a/app/src/server/tasks/queryModel.task.ts +++ b/app/src/server/tasks/queryModel.task.ts @@ -99,26 +99,27 @@ export const queryModel = defineTask("queryModel", async (task) = } : null; - const inputHash = hashObject(prompt as JsonValue); + const cacheKey = hashObject(prompt as JsonValue); let modelResponse = await prisma.modelResponse.create({ data: { - inputHash, + cacheKey, scenarioVariantCellId: cellId, requestedAt: new Date(), }, }); const response = await provider.getCompletion(prompt.modelInput, onStream); if (response.type === "success") { + const usage = provider.getUsage(prompt.modelInput, response.value); modelResponse = await prisma.modelResponse.update({ where: { id: modelResponse.id }, data: { - output: response.value as Prisma.InputJsonObject, + respPayload: response.value as Prisma.InputJsonObject, statusCode: response.statusCode, receivedAt: new Date(), - promptTokens: response.promptTokens, - completionTokens: response.completionTokens, - cost: response.cost, + inputTokens: usage?.inputTokens, + outputTokens: usage?.outputTokens, + cost: usage?.cost, }, }); diff --git a/app/src/server/utils/evaluations.ts b/app/src/server/utils/evaluations.ts index 9259d91..f3039f0 100644 --- a/app/src/server/utils/evaluations.ts +++ b/app/src/server/utils/evaluations.ts @@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => { const outputs = await prisma.modelResponse.findMany({ where: { outdated: false, - output: { + respPayload: { not: Prisma.AnyNull, }, scenarioVariantCell: { diff --git a/app/src/server/utils/generateNewCell.ts b/app/src/server/utils/generateNewCell.ts index 858781e..678740d 100644 --- a/app/src/server/utils/generateNewCell.ts +++ b/app/src/server/utils/generateNewCell.ts @@ -57,7 +57,7 @@ export const generateNewCell = async ( return; } - const inputHash = hashObject(parsedConstructFn); + const cacheKey = hashObject(parsedConstructFn); cell = await prisma.scenarioVariantCell.create({ data: { @@ -73,8 +73,8 @@ export const generateNewCell = async ( const matchingModelResponse = await prisma.modelResponse.findFirst({ where: { - inputHash, - output: { + cacheKey, + respPayload: { not: Prisma.AnyNull, }, }, @@ -92,7 +92,7 @@ export const generateNewCell = async ( data: { ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]), scenarioVariantCellId: cell.id, - output: matchingModelResponse.output as Prisma.InputJsonValue, + respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue, }, }); diff --git a/app/src/server/utils/runOneEval.ts b/app/src/server/utils/runOneEval.ts index a65f417..87f4664 100644 --- a/app/src/server/utils/runOneEval.ts +++ b/app/src/server/utils/runOneEval.ts @@ -71,7 +71,7 @@ export const runOneEval = async ( provider: SupportedProvider, ): Promise<{ result: number; details?: string }> => { const modelProvider = modelProviders[provider]; - const message = modelProvider.normalizeOutput(modelResponse.output); + const message = modelProvider.normalizeOutput(modelResponse.respPayload); if (!message) return { result: 0 };