From 1f8e3b820fc12c8bcd39c05098aa22916737174f Mon Sep 17 00:00:00 2001 From: David Corbitt Date: Thu, 10 Aug 2023 19:49:18 -0700 Subject: [PATCH] Rename prompt and completion tokens to input and output tokens --- .../migration.sql | 15 +++++++++++++++ app/prisma/schema.prisma | 4 ++-- .../OutputsTable/OutputCell/OutputStats.tsx | 8 ++++---- app/src/components/OutputsTable/VariantStats.tsx | 8 ++++---- app/src/components/tooltip/CostTooltip.tsx | 12 ++++++------ app/src/modelProviders/types.ts | 7 ++++--- .../server/api/routers/promptVariants.router.ts | 12 ++++++------ app/src/server/tasks/queryModel.task.ts | 7 ++++--- 8 files changed, 45 insertions(+), 28 deletions(-) create mode 100644 app/prisma/migrations/20230811023536_standardize_on_input_and_output_tokens_as_names/migration.sql diff --git a/app/prisma/migrations/20230811023536_standardize_on_input_and_output_tokens_as_names/migration.sql b/app/prisma/migrations/20230811023536_standardize_on_input_and_output_tokens_as_names/migration.sql new file mode 100644 index 0000000..10c5f45 --- /dev/null +++ b/app/prisma/migrations/20230811023536_standardize_on_input_and_output_tokens_as_names/migration.sql @@ -0,0 +1,15 @@ +/* + Warnings: + + - You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table. + - You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table. + +*/ + +-- Rename completionTokens to outputTokens +ALTER TABLE "ModelResponse" +RENAME COLUMN "completionTokens" TO "outputTokens"; + +-- Rename promptTokens to inputTokens +ALTER TABLE "ModelResponse" +RENAME COLUMN "promptTokens" TO "inputTokens"; diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma index 1600669..85aaa12 100644 --- a/app/prisma/schema.prisma +++ b/app/prisma/schema.prisma @@ -117,8 +117,8 @@ model ModelResponse { receivedAt DateTime? output Json? cost Float? - promptTokens Int? - completionTokens Int? + inputTokens Int? + outputTokens Int? statusCode Int? errorMessage String? retryTime DateTime? diff --git a/app/src/components/OutputsTable/OutputCell/OutputStats.tsx b/app/src/components/OutputsTable/OutputCell/OutputStats.tsx index ab0fafb..afb9c62 100644 --- a/app/src/components/OutputsTable/OutputCell/OutputStats.tsx +++ b/app/src/components/OutputsTable/OutputCell/OutputStats.tsx @@ -19,8 +19,8 @@ export const OutputStats = ({ ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime() : 0; - const promptTokens = modelResponse.promptTokens; - const completionTokens = modelResponse.completionTokens; + const inputTokens = modelResponse.inputTokens; + const outputTokens = modelResponse.outputTokens; return ( {modelResponse.cost && ( diff --git a/app/src/components/OutputsTable/VariantStats.tsx b/app/src/components/OutputsTable/VariantStats.tsx index 40bae1d..8c81fdb 100644 --- a/app/src/components/OutputsTable/VariantStats.tsx +++ b/app/src/components/OutputsTable/VariantStats.tsx @@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) { initialData: { evalResults: [], overallCost: 0, - promptTokens: 0, - completionTokens: 0, + inputTokens: 0, + outputTokens: 0, scenarioCount: 0, outputCount: 0, awaitingEvals: false, @@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) { {data.overallCost && ( diff --git a/app/src/components/tooltip/CostTooltip.tsx b/app/src/components/tooltip/CostTooltip.tsx index 68cf3ea..0e2cd17 100644 --- a/app/src/components/tooltip/CostTooltip.tsx +++ b/app/src/components/tooltip/CostTooltip.tsx @@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from import { BsCurrencyDollar } from "react-icons/bs"; type CostTooltipProps = { - promptTokens: number | null; - completionTokens: number | null; + inputTokens: number | null; + outputTokens: number | null; cost: number; } & TooltipProps; export const CostTooltip = ({ - promptTokens, - completionTokens, + inputTokens, + outputTokens, cost, children, ...props @@ -36,12 +36,12 @@ export const CostTooltip = ({ Prompt - {promptTokens ?? 0} + {inputTokens ?? 0} Completion - {completionTokens ?? 0} + {outputTokens ?? 0} diff --git a/app/src/modelProviders/types.ts b/app/src/modelProviders/types.ts index 5e5bf26..6b5e09e 100644 --- a/app/src/modelProviders/types.ts +++ b/app/src/modelProviders/types.ts @@ -43,9 +43,6 @@ export type CompletionResponse = value: T; timeToComplete: number; statusCode: number; - promptTokens?: number; - completionTokens?: number; - cost?: number; }; export type ModelProvider = { @@ -56,6 +53,10 @@ export type ModelProvider void) | null, ) => Promise>; + getUsage: ( + input: InputSchema, + output: OutputSchema, + ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null; // This is just a convenience for type inference, don't use it at runtime _outputSchema?: OutputSchema | null; diff --git a/app/src/server/api/routers/promptVariants.router.ts b/app/src/server/api/routers/promptVariants.router.ts index f19bbd0..5a0d1fc 100644 --- a/app/src/server/api/routers/promptVariants.router.ts +++ b/app/src/server/api/routers/promptVariants.router.ts @@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({ }, _sum: { cost: true, - promptTokens: true, - completionTokens: true, + inputTokens: true, + outputTokens: true, }, }); - const promptTokens = overallTokens._sum?.promptTokens ?? 0; - const completionTokens = overallTokens._sum?.completionTokens ?? 0; + const inputTokens = overallTokens._sum?.inputTokens ?? 0; + const outputTokens = overallTokens._sum?.outputTokens ?? 0; const awaitingEvals = !!evalResults.find( (result) => result.totalCount < scenarioCount * evals.length, @@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({ return { evalResults, - promptTokens, - completionTokens, + inputTokens, + outputTokens, overallCost: overallTokens._sum?.cost ?? 0, scenarioCount, outputCount, diff --git a/app/src/server/tasks/queryModel.task.ts b/app/src/server/tasks/queryModel.task.ts index d7a5dc8..e66fdfa 100644 --- a/app/src/server/tasks/queryModel.task.ts +++ b/app/src/server/tasks/queryModel.task.ts @@ -110,15 +110,16 @@ export const queryModel = defineTask("queryModel", async (task) = }); const response = await provider.getCompletion(prompt.modelInput, onStream); if (response.type === "success") { + const usage = provider.getUsage(prompt.modelInput, response.value); modelResponse = await prisma.modelResponse.update({ where: { id: modelResponse.id }, data: { output: response.value as Prisma.InputJsonObject, statusCode: response.statusCode, receivedAt: new Date(), - promptTokens: response.promptTokens, - completionTokens: response.completionTokens, - cost: response.cost, + inputTokens: usage?.inputTokens, + outputTokens: usage?.outputTokens, + cost: usage?.cost, }, });