Remove model from promptVariant and add cost

Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- might have to add it back in later if this causes trouble.

Added `cost` to modelOutput as well so we can cache that, which is important given that the cost calculations won't be the same between different API providers.
This commit is contained in:
Kyle Corbitt
2023-07-19 15:48:54 -07:00
parent 4c97b9f147
commit 60765e51ac
14 changed files with 78 additions and 1200 deletions

View File

@@ -62,6 +62,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,

View File

@@ -1,30 +1,24 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash-es";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type OpenAIChatModel } from "../types";
import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../modelStats";
export type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
output: ChatCompletion | null;
statusCode: number;
errorMessage: string | null;
timeToComplete: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
};
export async function getCompletion(
payload: CompletionCreateParams,
channel?: string,
): Promise<CompletionResponse> {
return getOpenAIChatCompletion(payload, channel);
}
export async function getOpenAIChatCompletion(
payload: CompletionCreateParams,
channel?: string,
@@ -42,7 +36,7 @@ export async function getOpenAIChatCompletion(
});
const resp: CompletionResponse = {
output: Prisma.JsonNull,
output: null,
errorMessage: null,
statusCode: response.status,
timeToComplete: 0,
@@ -59,7 +53,7 @@ export async function getOpenAIChatCompletion(
}
})().catch((err) => console.error(err));
if (finalOutput) {
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
resp.output = finalOutput;
resp.timeToComplete = Date.now() - start;
}
} else {
@@ -95,6 +89,13 @@ export async function getOpenAIChatCompletion(
resp.completionTokens = countOpenAIChatTokens(model, messages);
}
}
const stats = modelStats[resp.output?.model as SupportedModel];
if (stats && resp.promptTokens && resp.completionTokens) {
resp.cost =
resp.promptTokens * stats.promptTokenPrice +
resp.completionTokens * stats.completionTokenPrice;
}
} catch (e) {
console.error(e);
if (response.ok) {