Record model and cost when reporting logs (#136)

* Rename prompt and completion tokens to input and output tokens

* Add getUsage function

* Record model and cost when reporting log

* Remove unused imports

* Move UsageGraph to its own component

* Standardize model response fields

* Fix types
This commit is contained in:
arcticfly
2023-08-11 13:56:47 -07:00
committed by GitHub
parent f270579283
commit 8d1ee62ff1
24 changed files with 295 additions and 199 deletions

View File

@@ -0,0 +1,66 @@
/*
Warnings:
- You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table.
- You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table.
- You are about to rename the column `startTime` on the `LoggedCall` table to `requestedAt`. Ensure compatibility with application logic.
- You are about to rename the column `startTime` on the `LoggedCallModelResponse` table to `requestedAt`. Ensure compatibility with application logic.
- You are about to rename the column `endTime` on the `LoggedCallModelResponse` table to `receivedAt`. Ensure compatibility with application logic.
- You are about to rename the column `error` on the `LoggedCallModelResponse` table to `errorMessage`. Ensure compatibility with application logic.
- You are about to rename the column `respStatus` on the `LoggedCallModelResponse` table to `statusCode`. Ensure compatibility with application logic.
- You are about to rename the column `totalCost` on the `LoggedCallModelResponse` table to `cost`. Ensure compatibility with application logic.
- You are about to rename the column `inputHash` on the `ModelResponse` table to `cacheKey`. Ensure compatibility with application logic.
- You are about to rename the column `output` on the `ModelResponse` table to `respPayload`. Ensure compatibility with application logic.
*/
-- DropIndex
DROP INDEX "LoggedCall_startTime_idx";
-- DropIndex
DROP INDEX "ModelResponse_inputHash_idx";
-- Rename completionTokens to outputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "completionTokens" TO "outputTokens";
-- Rename promptTokens to inputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "promptTokens" TO "inputTokens";
-- AlterTable
ALTER TABLE "LoggedCall"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "endTime" TO "receivedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "error" TO "errorMessage";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "respStatus" TO "statusCode";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "totalCost" TO "cost";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "inputHash" TO "cacheKey";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "output" TO "respPayload";
-- CreateIndex
CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt");
-- CreateIndex
CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey");

View File

@@ -112,13 +112,13 @@ model ScenarioVariantCell {
model ModelResponse {
id String @id @default(uuid()) @db.Uuid
inputHash String
cacheKey String
requestedAt DateTime?
receivedAt DateTime?
output Json?
respPayload Json?
cost Float?
promptTokens Int?
completionTokens Int?
inputTokens Int?
outputTokens Int?
statusCode Int?
errorMessage String?
retryTime DateTime?
@@ -131,7 +131,7 @@ model ModelResponse {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[]
@@index([inputHash])
@@index([cacheKey])
}
enum EvalType {
@@ -256,7 +256,7 @@ model WorldChampEntrant {
model LoggedCall {
id String @id @default(uuid()) @db.Uuid
startTime DateTime
requestedAt DateTime
// True if this call was served from the cache, false otherwise
cacheHit Boolean
@@ -278,7 +278,7 @@ model LoggedCall {
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@index([startTime])
@@index([requestedAt])
}
model LoggedCallModelResponse {
@@ -287,14 +287,14 @@ model LoggedCallModelResponse {
reqPayload Json
// The HTTP status returned by the model provider
respStatus Int?
statusCode Int?
respPayload Json?
// Should be null if the request was successful, and some string if the request failed.
error String?
errorMessage String?
startTime DateTime
endTime DateTime
requestedAt DateTime
receivedAt DateTime
// Note: the function to calculate the cacheKey should include the project
// ID so we don't share cached responses between projects, which could be an
@@ -308,7 +308,7 @@ model LoggedCallModelResponse {
outputTokens Int?
finishReason String?
completionId String?
totalCost Decimal? @db.Decimal(18, 12)
cost Decimal? @db.Decimal(18, 12)
// The LoggedCall that created this LoggedCallModelResponse
originalLoggedCallId String @unique @db.Uuid

View File

@@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) {
MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!;
const model = template.reqPayload.model;
// choose random time in the last two weeks, with a bias towards the last few days
const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
// choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5
const delay =
model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4;
const endTime = new Date(startTime.getTime() + delay);
const receivedAt = new Date(requestedAt.getTime() + delay);
loggedCallsToCreate.push({
id: loggedCallId,
cacheHit: false,
startTime,
requestedAt,
projectId: project.id,
createdAt: startTime,
createdAt: requestedAt,
});
const { promptTokenPrice, completionTokenPrice } =
@@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) {
loggedCallModelResponsesToCreate.push({
id: loggedCallModelResponseId,
startTime,
endTime,
requestedAt,
receivedAt,
originalLoggedCallId: loggedCallId,
reqPayload: template.reqPayload,
respPayload: template.respPayload,
respStatus: template.respStatus,
error: template.error,
createdAt: startTime,
statusCode: template.respStatus,
errorMessage: template.error,
createdAt: requestedAt,
cacheKey: hashRequest(project.id, template.reqPayload as JsonValue),
durationMs: endTime.getTime() - startTime.getTime(),
durationMs: receivedAt.getTime() - requestedAt.getTime(),
inputTokens: template.inputTokens,
outputTokens: template.outputTokens,
finishReason: template.finishReason,
totalCost:
template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
});
loggedCallsToUpdate.push({
where: {