Record model and cost when reporting logs (#136)
* Rename prompt and completion tokens to input and output tokens
* Add getUsage function
* Record model and cost when reporting log
* Remove unused imports
* Move UsageGraph to its own component
* Standardize model response fields
* Fix types
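For orientation before the diff: the commit's core addition is a per-provider getUsage hook that turns a request/response pair into token counts and a cost. The sketch below restates the contract from the types.ts and OpenAI provider hunks further down; it is a minimal illustration, and the two token prices are hypothetical example values, not rates taken from this repo.

// Shape returned by the new ModelProvider.getUsage (from the types.ts hunk below)
type Usage = { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number };

// Example per-token prices, assumed for illustration only
const promptTokenPrice = 0.00003;
const completionTokenPrice = 0.00006;

// Same arithmetic the OpenAI provider uses below:
// cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice
function usageFromCounts(inputTokens: number, outputTokens: number): Usage {
  return {
    inputTokens,
    outputTokens,
    cost: inputTokens * promptTokenPrice + outputTokens * completionTokenPrice,
  };
}

// usageFromCounts(1000, 500) => { inputTokens: 1000, outputTokens: 500, cost: 0.06 }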
@@ -0,0 +1,66 @@
+/*
+  Warnings:
+
+  - You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table.
+  - You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table.
+  - You are about to rename the column `startTime` on the `LoggedCall` table to `requestedAt`. Ensure compatibility with application logic.
+  - You are about to rename the column `startTime` on the `LoggedCallModelResponse` table to `requestedAt`. Ensure compatibility with application logic.
+  - You are about to rename the column `endTime` on the `LoggedCallModelResponse` table to `receivedAt`. Ensure compatibility with application logic.
+  - You are about to rename the column `error` on the `LoggedCallModelResponse` table to `errorMessage`. Ensure compatibility with application logic.
+  - You are about to rename the column `respStatus` on the `LoggedCallModelResponse` table to `statusCode`. Ensure compatibility with application logic.
+  - You are about to rename the column `totalCost` on the `LoggedCallModelResponse` table to `cost`. Ensure compatibility with application logic.
+  - You are about to rename the column `inputHash` on the `ModelResponse` table to `cacheKey`. Ensure compatibility with application logic.
+  - You are about to rename the column `output` on the `ModelResponse` table to `respPayload`. Ensure compatibility with application logic.
+
+*/
+-- DropIndex
+DROP INDEX "LoggedCall_startTime_idx";
+
+-- DropIndex
+DROP INDEX "ModelResponse_inputHash_idx";
+
+-- Rename completionTokens to outputTokens
+ALTER TABLE "ModelResponse"
+RENAME COLUMN "completionTokens" TO "outputTokens";
+
+-- Rename promptTokens to inputTokens
+ALTER TABLE "ModelResponse"
+RENAME COLUMN "promptTokens" TO "inputTokens";
+
+-- AlterTable
+ALTER TABLE "LoggedCall"
+RENAME COLUMN "startTime" TO "requestedAt";
+
+-- AlterTable
+ALTER TABLE "LoggedCallModelResponse"
+RENAME COLUMN "startTime" TO "requestedAt";
+
+-- AlterTable
+ALTER TABLE "LoggedCallModelResponse"
+RENAME COLUMN "endTime" TO "receivedAt";
+
+-- AlterTable
+ALTER TABLE "LoggedCallModelResponse"
+RENAME COLUMN "error" TO "errorMessage";
+
+-- AlterTable
+ALTER TABLE "LoggedCallModelResponse"
+RENAME COLUMN "respStatus" TO "statusCode";
+
+-- AlterTable
+ALTER TABLE "LoggedCallModelResponse"
+RENAME COLUMN "totalCost" TO "cost";
+
+-- AlterTable
+ALTER TABLE "ModelResponse"
+RENAME COLUMN "inputHash" TO "cacheKey";
+
+-- AlterTable
+ALTER TABLE "ModelResponse"
+RENAME COLUMN "output" TO "respPayload";
+
+-- CreateIndex
+CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt");
+
+-- CreateIndex
+CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey");
@@ -112,13 +112,13 @@ model ScenarioVariantCell {
 model ModelResponse {
   id               String    @id @default(uuid()) @db.Uuid
 
-  inputHash        String
+  cacheKey         String
   requestedAt      DateTime?
   receivedAt       DateTime?
-  output           Json?
+  respPayload      Json?
   cost             Float?
-  promptTokens     Int?
-  completionTokens Int?
+  inputTokens      Int?
+  outputTokens     Int?
   statusCode       Int?
   errorMessage     String?
   retryTime        DateTime?

@@ -131,7 +131,7 @@ model ModelResponse {
   scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
   outputEvaluations   OutputEvaluation[]
 
-  @@index([inputHash])
+  @@index([cacheKey])
 }
 
 enum EvalType {

@@ -256,7 +256,7 @@ model WorldChampEntrant {
 model LoggedCall {
   id String @id @default(uuid()) @db.Uuid
 
-  startTime DateTime
+  requestedAt DateTime
 
   // True if this call was served from the cache, false otherwise
   cacheHit Boolean

@@ -278,7 +278,7 @@ model LoggedCall {
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
 
-  @@index([startTime])
+  @@index([requestedAt])
 }
 
 model LoggedCallModelResponse {

@@ -287,14 +287,14 @@ model LoggedCallModelResponse {
   reqPayload Json
 
   // The HTTP status returned by the model provider
-  respStatus Int?
+  statusCode Int?
   respPayload Json?
 
   // Should be null if the request was successful, and some string if the request failed.
-  error String?
+  errorMessage String?
 
-  startTime DateTime
-  endTime   DateTime
+  requestedAt DateTime
+  receivedAt  DateTime
 
   // Note: the function to calculate the cacheKey should include the project
   // ID so we don't share cached responses between projects, which could be an

@@ -308,7 +308,7 @@ model LoggedCallModelResponse {
   outputTokens Int?
   finishReason String?
   completionId String?
-  totalCost Decimal? @db.Decimal(18, 12)
+  cost      Decimal? @db.Decimal(18, 12)
 
   // The LoggedCall that created this LoggedCallModelResponse
   originalLoggedCallId String @unique @db.Uuid
@@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) {
     MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!;
   const model = template.reqPayload.model;
   // choose random time in the last two weeks, with a bias towards the last few days
-  const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
+  const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
   // choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5
   const delay =
     model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4;
-  const endTime = new Date(startTime.getTime() + delay);
+  const receivedAt = new Date(requestedAt.getTime() + delay);
   loggedCallsToCreate.push({
     id: loggedCallId,
     cacheHit: false,
-    startTime,
+    requestedAt,
     projectId: project.id,
-    createdAt: startTime,
+    createdAt: requestedAt,
   });
 
   const { promptTokenPrice, completionTokenPrice } =

@@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) {
 
   loggedCallModelResponsesToCreate.push({
     id: loggedCallModelResponseId,
-    startTime,
-    endTime,
+    requestedAt,
+    receivedAt,
     originalLoggedCallId: loggedCallId,
     reqPayload: template.reqPayload,
     respPayload: template.respPayload,
-    respStatus: template.respStatus,
-    error: template.error,
-    createdAt: startTime,
+    statusCode: template.respStatus,
+    errorMessage: template.error,
+    createdAt: requestedAt,
     cacheKey: hashRequest(project.id, template.reqPayload as JsonValue),
-    durationMs: endTime.getTime() - startTime.getTime(),
+    durationMs: receivedAt.getTime() - requestedAt.getTime(),
     inputTokens: template.inputTokens,
     outputTokens: template.outputTokens,
     finishReason: template.finishReason,
-    totalCost:
-      template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
+    cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
   });
   loggedCallsToUpdate.push({
     where: {
@@ -107,7 +107,7 @@ export default function OutputCell({
 
   if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
 
-  const showLogs = !streamedMessage && !mostRecentResponse?.output;
+  const showLogs = !streamedMessage && !mostRecentResponse?.respPayload;
 
   if (showLogs)
     return (

@@ -160,13 +160,13 @@ export default function OutputCell({
     </CellWrapper>
   );
 
-  const normalizedOutput = mostRecentResponse?.output
-    ? provider.normalizeOutput(mostRecentResponse?.output)
+  const normalizedOutput = mostRecentResponse?.respPayload
+    ? provider.normalizeOutput(mostRecentResponse?.respPayload)
     : streamedMessage
     ? provider.normalizeOutput(streamedMessage)
     : null;
 
-  if (mostRecentResponse?.output && normalizedOutput?.type === "json") {
+  if (mostRecentResponse?.respPayload && normalizedOutput?.type === "json") {
     return (
       <CellWrapper>
         <SyntaxHighlighter
@@ -19,8 +19,8 @@ export const OutputStats = ({
     ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
     : 0;
 
-  const promptTokens = modelResponse.promptTokens;
-  const completionTokens = modelResponse.completionTokens;
+  const inputTokens = modelResponse.inputTokens;
+  const outputTokens = modelResponse.outputTokens;
 
   return (
     <HStack

@@ -55,8 +55,8 @@ export const OutputStats = ({
       </HStack>
       {modelResponse.cost && (
         <CostTooltip
-          promptTokens={promptTokens}
-          completionTokens={completionTokens}
+          inputTokens={inputTokens}
+          outputTokens={outputTokens}
           cost={modelResponse.cost}
         >
           <HStack spacing={0}>
@@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
     initialData: {
       evalResults: [],
       overallCost: 0,
-      promptTokens: 0,
-      completionTokens: 0,
+      inputTokens: 0,
+      outputTokens: 0,
       scenarioCount: 0,
       outputCount: 0,
       awaitingEvals: false,

@@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
       </HStack>
      {data.overallCost && (
        <CostTooltip
-          promptTokens={data.promptTokens}
-          completionTokens={data.completionTokens}
+          inputTokens={data.inputTokens}
+          outputTokens={data.outputTokens}
          cost={data.overallCost}
        >
          <HStack spacing={0} align="center" color="gray.500">
@@ -90,9 +90,9 @@ function TableRow({
   isExpanded: boolean;
   onToggle: () => void;
 }) {
-  const isError = loggedCall.modelResponse?.respStatus !== 200;
-  const timeAgo = dayjs(loggedCall.startTime).fromNow();
-  const fullTime = dayjs(loggedCall.startTime).toString();
+  const isError = loggedCall.modelResponse?.statusCode !== 200;
+  const timeAgo = dayjs(loggedCall.requestedAt).fromNow();
+  const fullTime = dayjs(loggedCall.requestedAt).toString();
 
   const model = useMemo(
     () => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value,

@@ -124,7 +124,7 @@ function TableRow({
       <Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td>
       <Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td>
       <Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric>
-        {loggedCall.modelResponse?.respStatus ?? "No response"}
+        {loggedCall.modelResponse?.statusCode ?? "No response"}
       </Td>
     </Tr>
     <Tr>
app/src/components/dashboard/UsageGraph.tsx (new file, 61 lines)
@@ -0,0 +1,61 @@
+import {
+  ResponsiveContainer,
+  LineChart,
+  Line,
+  XAxis,
+  YAxis,
+  CartesianGrid,
+  Tooltip,
+  Legend,
+} from "recharts";
+import { useMemo } from "react";
+
+import { useSelectedProject } from "~/utils/hooks";
+import dayjs from "~/utils/dayjs";
+import { api } from "~/utils/api";
+
+export default function UsageGraph() {
+  const { data: selectedProject } = useSelectedProject();
+
+  const stats = api.dashboard.stats.useQuery(
+    { projectId: selectedProject?.id ?? "" },
+    { enabled: !!selectedProject },
+  );
+
+  const data = useMemo(() => {
+    return (
+      stats.data?.periods.map(({ period, numQueries, cost }) => ({
+        period,
+        Requests: numQueries,
+        "Total Spent (USD)": parseFloat(cost.toString()),
+      })) || []
+    );
+  }, [stats.data]);
+
+  return (
+    <ResponsiveContainer width="100%" height={400}>
+      <LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
+        <XAxis dataKey="period" tickFormatter={(str: string) => dayjs(str).format("MMM D")} />
+        <YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
+        <YAxis
+          yAxisId="right"
+          dataKey="Total Spent (USD)"
+          orientation="right"
+          unit="$"
+          stroke="#82ca9d"
+        />
+        <Tooltip />
+        <Legend />
+        <CartesianGrid stroke="#f5f5f5" />
+        <Line dataKey="Requests" stroke="#8884d8" yAxisId="left" dot={false} strokeWidth={2} />
+        <Line
+          dataKey="Total Spent (USD)"
+          stroke="#82ca9d"
+          yAxisId="right"
+          dot={false}
+          strokeWidth={2}
+        />
+      </LineChart>
+    </ResponsiveContainer>
+  );
+}
@@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from
 import { BsCurrencyDollar } from "react-icons/bs";
 
 type CostTooltipProps = {
-  promptTokens: number | null;
-  completionTokens: number | null;
+  inputTokens: number | null;
+  outputTokens: number | null;
   cost: number;
 } & TooltipProps;
 
 export const CostTooltip = ({
-  promptTokens,
-  completionTokens,
+  inputTokens,
+  outputTokens,
   cost,
   children,
   ...props

@@ -36,12 +36,12 @@ export const CostTooltip = ({
         <HStack>
           <VStack w="28" spacing={1}>
             <Text>Prompt</Text>
-            <Text>{promptTokens ?? 0}</Text>
+            <Text>{inputTokens ?? 0}</Text>
           </VStack>
           <Divider borderColor="gray.200" h={8} orientation="vertical" />
           <VStack w="28" spacing={1}>
             <Text whiteSpace="nowrap">Completion</Text>
-            <Text>{completionTokens ?? 0}</Text>
+            <Text>{outputTokens ?? 0}</Text>
           </VStack>
         </HStack>
       </VStack>
@@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = {
   inputSchema: inputSchema as JSONSchema4,
   canStream: true,
   getCompletion,
+  getUsage: (input, output) => {
+    // TODO: add usage logic
+    return null;
+  },
   ...frontendModelProvider,
 };
 
@@ -4,14 +4,10 @@ import {
   type ChatCompletion,
   type CompletionCreateParams,
 } from "openai/resources/chat";
-import { countOpenAIChatTokens } from "~/utils/countTokens";
 import { type CompletionResponse } from "../types";
 import { isArray, isString, omit } from "lodash-es";
 import { openai } from "~/server/utils/openai";
-import { truthyFilter } from "~/utils/utils";
 import { APIError } from "openai";
-import frontendModelProvider from "./frontend";
-import modelProvider, { type SupportedModel } from ".";
 
 const mergeStreamedChunks = (
   base: ChatCompletion | null,

@@ -60,9 +56,6 @@ export async function getCompletion(
 ): Promise<CompletionResponse<ChatCompletion>> {
   const start = Date.now();
   let finalCompletion: ChatCompletion | null = null;
-  let promptTokens: number | undefined = undefined;
-  let completionTokens: number | undefined = undefined;
-  const modelName = modelProvider.getModel(input) as SupportedModel;
 
   try {
     if (onStream) {

@@ -86,16 +79,6 @@ export async function getCompletion(
           autoRetry: false,
         };
       }
-      try {
-        promptTokens = countOpenAIChatTokens(modelName, input.messages);
-        completionTokens = countOpenAIChatTokens(
-          modelName,
-          finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
-        );
-      } catch (err) {
-        // TODO handle this, library seems like maybe it doesn't work with function calls?
-        console.error(err);
-      }
     } else {
       const resp = await openai.chat.completions.create(
         { ...input, stream: false },

@@ -104,25 +87,14 @@ export async function getCompletion(
         },
       );
       finalCompletion = resp;
-      promptTokens = resp.usage?.prompt_tokens ?? 0;
-      completionTokens = resp.usage?.completion_tokens ?? 0;
     }
     const timeToComplete = Date.now() - start;
 
-    const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
-    let cost = undefined;
-    if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
-      cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
-    }
 
     return {
       type: "success",
       statusCode: 200,
       value: finalCompletion,
       timeToComplete,
-      promptTokens,
-      completionTokens,
-      cost,
     };
   } catch (error: unknown) {
     if (error instanceof APIError) {
@@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
 import { getCompletion } from "./getCompletion";
 import frontendModelProvider from "./frontend";
+import { countOpenAIChatTokens } from "~/utils/countTokens";
+import { truthyFilter } from "~/utils/utils";
 
 const supportedModels = [
   "gpt-4-0613",

@@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = {
   inputSchema: inputSchema as JSONSchema4,
   canStream: true,
   getCompletion,
+  getUsage: (input, output) => {
+    if (output.choices.length === 0) return null;
+
+    const model = modelProvider.getModel(input);
+    if (!model) return null;
+
+    let inputTokens: number;
+    let outputTokens: number;
+
+    if (output.usage) {
+      inputTokens = output.usage.prompt_tokens;
+      outputTokens = output.usage.completion_tokens;
+    } else {
+      try {
+        inputTokens = countOpenAIChatTokens(model, input.messages);
+        outputTokens = countOpenAIChatTokens(
+          model,
+          output.choices.map((c) => c.message).filter(truthyFilter),
+        );
+      } catch (err) {
+        inputTokens = 0;
+        outputTokens = 0;
+        // TODO handle this, library seems like maybe it doesn't work with function calls?
+        console.error(err);
+      }
+    }
+
+    const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model];
+    let cost = undefined;
+    if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) {
+      cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice;
+    }
+
+    return { inputTokens: inputTokens, outputTokens: outputTokens, cost };
+  },
   ...frontendModelProvider,
 };
 
@@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = {
   },
   canStream: true,
   getCompletion,
+  getUsage: (input, output) => {
+    // TODO: add usage logic
+    return null;
+  },
   ...frontendModelProvider,
 };
 
@@ -43,9 +43,6 @@ export type CompletionResponse<T> =
       value: T;
       timeToComplete: number;
       statusCode: number;
-      promptTokens?: number;
-      completionTokens?: number;
-      cost?: number;
     };
 
 export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {

@@ -56,6 +53,10 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
     input: InputSchema,
     onStream: ((partialOutput: OutputSchema) => void) | null,
   ) => Promise<CompletionResponse<OutputSchema>>;
+  getUsage: (
+    input: InputSchema,
+    output: OutputSchema,
+  ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
 
   // This is just a convenience for type inference, don't use it at runtime
   _outputSchema?: OutputSchema | null;
@@ -18,26 +18,15 @@ import {
   Breadcrumb,
   BreadcrumbItem,
 } from "@chakra-ui/react";
-import {
-  LineChart,
-  Line,
-  XAxis,
-  YAxis,
-  CartesianGrid,
-  Tooltip,
-  Legend,
-  ResponsiveContainer,
-} from "recharts";
 import { Ban, DollarSign, Hash } from "lucide-react";
-import { useMemo } from "react";
 
 import AppShell from "~/components/nav/AppShell";
 import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
 import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
 import { useSelectedProject } from "~/utils/hooks";
-import dayjs from "~/utils/dayjs";
 import { api } from "~/utils/api";
 import LoggedCallTable from "~/components/dashboard/LoggedCallTable";
+import UsageGraph from "~/components/dashboard/UsageGraph";
 
 export default function LoggedCalls() {
   const { data: selectedProject } = useSelectedProject();

@@ -47,16 +36,6 @@ export default function LoggedCalls() {
     { enabled: !!selectedProject },
   );
 
-  const data = useMemo(() => {
-    return (
-      stats.data?.periods.map(({ period, numQueries, totalCost }) => ({
-        period,
-        Requests: numQueries,
-        "Total Spent (USD)": parseFloat(totalCost.toString()),
-      })) || []
-    );
-  }, [stats.data]);
 
   return (
     <AppShell requireAuth>
       <PageHeaderContainer>

@@ -83,39 +62,7 @@ export default function LoggedCalls() {
           </Heading>
         </CardHeader>
         <CardBody>
-          <ResponsiveContainer width="100%" height={400}>
-            <LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
-              <XAxis
-                dataKey="period"
-                tickFormatter={(str: string) => dayjs(str).format("MMM D")}
-              />
-              <YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
-              <YAxis
-                yAxisId="right"
-                dataKey="Total Spent (USD)"
-                orientation="right"
-                unit="$"
-                stroke="#82ca9d"
-              />
-              <Tooltip />
-              <Legend />
-              <CartesianGrid stroke="#f5f5f5" />
-              <Line
-                dataKey="Requests"
-                stroke="#8884d8"
-                yAxisId="left"
-                dot={false}
-                strokeWidth={2}
-              />
-              <Line
-                dataKey="Total Spent (USD)"
-                stroke="#82ca9d"
-                yAxisId="right"
-                dot={false}
-                strokeWidth={2}
-              />
-            </LineChart>
-          </ResponsiveContainer>
+          <UsageGraph />
         </CardBody>
       </Card>
       <VStack spacing="4" width="300px" align="stretch">

@@ -127,7 +74,7 @@ export default function LoggedCalls() {
               <Icon as={DollarSign} boxSize={4} color="gray.500" />
             </HStack>
             <StatNumber>
-              ${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)}
+              ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)}
             </StatNumber>
           </Stat>
         </CardBody>
@@ -38,7 +38,10 @@ export default function Settings() {
       id: selectedProject.id,
       updates: { name },
     });
-    await Promise.all([utils.projects.get.invalidate({ id: selectedProject.id })]);
+    await Promise.all([
+      utils.projects.get.invalidate({ id: selectedProject.id }),
+      utils.projects.list.invalidate(),
+    ]);
   }
 }, [updateMutation, selectedProject]);
 
@@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({
       )
       .where("projectId", "=", input.projectId)
       .select(({ fn }) => [
-        sql<Date>`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"),
+        sql<Date>`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
         sql<number>`count("LoggedCall"."id")::int`.as("numQueries"),
-        fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
+        fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
       ])
       .groupBy("period")
       .orderBy("period")

@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
         backfilledPeriods.unshift({
           period: dayjs(dayToMatch).toDate(),
           numQueries: 0,
-          totalCost: 0,
+          cost: 0,
         });
       }
       dayToMatch = dayToMatch.subtract(1, "day");

@@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({
       )
       .where("projectId", "=", input.projectId)
       .select(({ fn }) => [
-        fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
+        fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
         fn.count("LoggedCall.id").as("numQueries"),
       ])
       .executeTakeFirst();

@@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({
         "LoggedCall.id",
         "LoggedCallModelResponse.originalLoggedCallId",
       )
-      .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"])
-      .where("respStatus", ">", 200)
+      .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
+      .where("statusCode", ">", 200)
       .groupBy("code")
       .orderBy("count", "desc")
       .execute();

@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
   // https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
   loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
     const loggedCalls = await prisma.loggedCall.findMany({
-      orderBy: { startTime: "desc" },
+      orderBy: { requestedAt: "desc" },
       include: { tags: true, modelResponse: true },
       take: 20,
     });
@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
         ...modelResponseData,
         id: newModelResponseId,
         scenarioVariantCellId: newCellId,
-        output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
+        respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
       });
       for (const evaluation of outputEvaluations) {
         outputEvaluationsToCreate.push({
@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
 import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
 import { hashRequest } from "~/server/utils/hashObject";
+import modelProvider from "~/modelProviders/openai-ChatCompletion";
+import {
+  type ChatCompletion,
+  type CompletionCreateParams,
+} from "openai/resources/chat/completions";
 
 const reqValidator = z.object({
   model: z.string(),

@@ -16,11 +21,6 @@ const reqValidator = z.object({
 const respValidator = z.object({
   id: z.string(),
   model: z.string(),
-  usage: z.object({
-    total_tokens: z.number(),
-    prompt_tokens: z.number(),
-    completion_tokens: z.number(),
-  }),
   choices: z.array(
     z.object({
       finish_reason: z.string(),

@@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({
         originalLoggedCall: true,
       },
       orderBy: {
-        startTime: "desc",
+        requestedAt: "desc",
       },
     });
 

@@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({
     await prisma.loggedCall.create({
       data: {
         projectId: key.projectId,
-        startTime: new Date(input.startTime),
+        requestedAt: new Date(input.startTime),
         cacheHit: true,
         modelResponseId: existingResponse.id,
       },

@@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({
     const newLoggedCallId = uuidv4();
     const newModelResponseId = uuidv4();
 
-    const usage = respPayload.success ? respPayload.data.usage : undefined;
+    let usage;
+    if (reqPayload.success && respPayload.success) {
+      usage = modelProvider.getUsage(
+        input.reqPayload as CompletionCreateParams,
+        input.respPayload as ChatCompletion,
+      );
+    }
 
     await prisma.$transaction([
       prisma.loggedCall.create({
         data: {
           id: newLoggedCallId,
           projectId: key.projectId,
-          startTime: new Date(input.startTime),
+          requestedAt: new Date(input.startTime),
           cacheHit: false,
         },
       }),

@@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({
         data: {
           id: newModelResponseId,
           originalLoggedCallId: newLoggedCallId,
-          startTime: new Date(input.startTime),
-          endTime: new Date(input.endTime),
+          requestedAt: new Date(input.startTime),
+          receivedAt: new Date(input.endTime),
           reqPayload: input.reqPayload as Prisma.InputJsonValue,
           respPayload: input.respPayload as Prisma.InputJsonValue,
-          respStatus: input.respStatus,
-          error: input.error,
+          statusCode: input.respStatus,
+          errorMessage: input.error,
           durationMs: input.endTime - input.startTime,
-          ...(respPayload.success
-            ? {
-                cacheKey: requestHash,
-                inputTokens: usage ? usage.prompt_tokens : undefined,
-                outputTokens: usage ? usage.completion_tokens : undefined,
-              }
-            : null),
+          cacheKey: respPayload.success ? requestHash : null,
+          inputTokens: usage?.inputTokens,
+          outputTokens: usage?.outputTokens,
+          cost: usage?.cost,
         },
       }),
       // Avoid foreign key constraint error by updating the logged call after the model response is created

@@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({
       }),
     ]);
 
-    if (input.tags) {
-      const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
-        loggedCallId: newLoggedCallId,
-        // sanitize tags
-        name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
-        value,
-      }));
+    const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
+      loggedCallId: newLoggedCallId,
+      // sanitize tags
+      name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
+      value,
+    }));
 
-      if (reqPayload.success) {
-        tagsToCreate.push({
-          loggedCallId: newLoggedCallId,
-          name: "$model",
-          value: reqPayload.data.model,
-        });
-      }
-      await prisma.loggedCallTag.createMany({
-        data: tagsToCreate,
-      });
+    if (reqPayload.success) {
+      tagsToCreate.push({
+        loggedCallId: newLoggedCallId,
+        name: "$model",
+        value: reqPayload.data.model,
+      });
     }
+
+    await prisma.loggedCallTag.createMany({
+      data: tagsToCreate,
+    });
   }),
 });
@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
       where: {
         modelResponse: {
           outdated: false,
-          output: { not: Prisma.AnyNull },
+          respPayload: { not: Prisma.AnyNull },
           scenarioVariantCell: {
             promptVariant: {
               id: input.variantId,

@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
         modelResponses: {
           some: {
             outdated: false,
-            output: {
+            respPayload: {
               not: Prisma.AnyNull,
             },
           },

@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
     const overallTokens = await prisma.modelResponse.aggregate({
       where: {
         outdated: false,
-        output: {
+        respPayload: {
           not: Prisma.AnyNull,
         },
         scenarioVariantCell: {

@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
       },
       _sum: {
         cost: true,
-        promptTokens: true,
-        completionTokens: true,
+        inputTokens: true,
+        outputTokens: true,
       },
     });
 
-    const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-    const completionTokens = overallTokens._sum?.completionTokens ?? 0;
+    const inputTokens = overallTokens._sum?.inputTokens ?? 0;
+    const outputTokens = overallTokens._sum?.outputTokens ?? 0;
 
     const awaitingEvals = !!evalResults.find(
       (result) => result.totalCount < scenarioCount * evals.length,

@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
 
     return {
       evalResults,
-      promptTokens,
-      completionTokens,
+      inputTokens,
+      outputTokens,
       overallCost: overallTokens._sum?.cost ?? 0,
       scenarioCount,
       outputCount,
@@ -99,26 +99,27 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
       }
     : null;
 
-  const inputHash = hashObject(prompt as JsonValue);
+  const cacheKey = hashObject(prompt as JsonValue);
 
   let modelResponse = await prisma.modelResponse.create({
     data: {
-      inputHash,
+      cacheKey,
       scenarioVariantCellId: cellId,
       requestedAt: new Date(),
     },
   });
   const response = await provider.getCompletion(prompt.modelInput, onStream);
   if (response.type === "success") {
+    const usage = provider.getUsage(prompt.modelInput, response.value);
     modelResponse = await prisma.modelResponse.update({
       where: { id: modelResponse.id },
       data: {
-        output: response.value as Prisma.InputJsonObject,
+        respPayload: response.value as Prisma.InputJsonObject,
         statusCode: response.statusCode,
         receivedAt: new Date(),
-        promptTokens: response.promptTokens,
-        completionTokens: response.completionTokens,
-        cost: response.cost,
+        inputTokens: usage?.inputTokens,
+        outputTokens: usage?.outputTokens,
+        cost: usage?.cost,
       },
     });
 
@@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => {
   const outputs = await prisma.modelResponse.findMany({
     where: {
       outdated: false,
-      output: {
+      respPayload: {
         not: Prisma.AnyNull,
       },
       scenarioVariantCell: {
@@ -57,7 +57,7 @@ export const generateNewCell = async (
     return;
   }
 
-  const inputHash = hashObject(parsedConstructFn);
+  const cacheKey = hashObject(parsedConstructFn);
 
   cell = await prisma.scenarioVariantCell.create({
     data: {

@@ -73,8 +73,8 @@ export const generateNewCell = async (
 
   const matchingModelResponse = await prisma.modelResponse.findFirst({
     where: {
-      inputHash,
-      output: {
+      cacheKey,
+      respPayload: {
         not: Prisma.AnyNull,
       },
     },

@@ -92,7 +92,7 @@ export const generateNewCell = async (
     data: {
       ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
       scenarioVariantCellId: cell.id,
-      output: matchingModelResponse.output as Prisma.InputJsonValue,
+      respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue,
     },
   });
 
@@ -71,7 +71,7 @@ export const runOneEval = async (
   provider: SupportedProvider,
 ): Promise<{ result: number; details?: string }> => {
   const modelProvider = modelProviders[provider];
-  const message = modelProvider.normalizeOutput(modelResponse.output);
+  const message = modelProvider.normalizeOutput(modelResponse.respPayload);
 
   if (!message) return { result: 0 };
 