Remove model from promptVariant and add cost

Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- might have to add it back in later if this causes trouble.

Added `cost` to modelOutput as well so we can cache it, which matters because cost calculations differ between API providers.
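
The upshot: cost is computed once, when the completion comes back and the provider/model pricing is still known, then cached on ModelOutput; per-variant totals become a plain aggregate instead of per-provider recomputation. A minimal sketch of the read side (assumed relation names; the real query lives in the promptVariants router diff below):

```ts
import { prisma } from "~/server/db";

// Sketch only: sums the cached per-output costs for one variant.
// `scenarioVariantCell.promptVariantId` is an assumed relation path.
async function variantTotals(variantId: string) {
  const totals = await prisma.modelOutput.aggregate({
    where: { scenarioVariantCell: { promptVariantId: variantId } },
    _sum: { cost: true, promptTokens: true, completionTokens: true },
  });
  return {
    overallCost: totals._sum?.cost ?? 0,
    promptTokens: totals._sum?.promptTokens ?? 0,
    completionTokens: totals._sum?.completionTokens ?? 0,
  };
}
```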
Author: Kyle Corbitt
Date: 2023-07-19 15:48:54 -07:00
Parent: 4c97b9f147
Commit: 60765e51ac
14 changed files with 78 additions and 1200 deletions

pnpm-lock.yaml (generated)

@@ -1,4 +1,4 @@
-lockfileVersion: '6.0'
+lockfileVersion: '6.1'
 
 settings:
   autoInstallPeers: true


@@ -0,0 +1,8 @@
+/*
+  Warnings:
+
+  - You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
+
+*/
+-- AlterTable
+ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;


@@ -114,6 +114,7 @@ model ModelOutput {
   inputHash        String
   output           Json
   timeToComplete   Int    @default(0)
+  cost             Float?
   promptTokens     Int?
   completionTokens Int?

File diff suppressed because one or more lines are too long


@@ -129,7 +129,7 @@ export default function OutputCell({
         )}
       </SyntaxHighlighter>
     </VStack>
-    <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
+    <OutputStats modelOutput={modelOutput} scenario={scenario} />
   </VStack>
 );
 }
@@ -143,9 +143,7 @@ export default function OutputCell({
       <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
       <Text>{contentToDisplay}</Text>
     </VStack>
-    {modelOutput && (
-      <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
-    )}
+    {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
   </VStack>
 );
 }


@@ -1,19 +1,14 @@
-import { type SupportedModel } from "~/server/types";
 import { type Scenario } from "../types";
 import { type RouterOutputs } from "~/utils/api";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
 import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";
 
-const SHOW_COST = true;
 const SHOW_TIME = true;
 
 export const OutputStats = ({
-  model,
   modelOutput,
 }: {
-  model: SupportedModel | string | null;
   modelOutput: NonNullable<
     NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
   >;
@@ -24,12 +19,6 @@ export const OutputStats = ({
   const promptTokens = modelOutput.promptTokens;
   const completionTokens = modelOutput.completionTokens;
 
-  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
-  const completionCost =
-    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-  const cost = promptCost + completionCost;
-
   return (
     <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
@@ -53,11 +42,15 @@ export const OutputStats = ({
         );
       })}
     </HStack>
-    {SHOW_COST && (
-      <CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
+    {modelOutput.cost && (
+      <CostTooltip
+        promptTokens={promptTokens}
+        completionTokens={completionTokens}
+        cost={modelOutput.cost}
+      >
       <HStack spacing={0}>
         <Icon as={BsCurrencyDollar} />
-        <Text mr={1}>{cost.toFixed(3)}</Text>
+        <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
       </HStack>
     </CostTooltip>
   )}


@@ -1,4 +1,4 @@
-import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
+import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
 import { type PromptVariant } from "./types";
 import { cellPadding } from "../constants";
 import { api } from "~/utils/api";
@@ -69,7 +69,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
         );
       })}
     </HStack>
-    {data.overallCost && !data.awaitingRetrievals ? (
+    {data.overallCost && !data.awaitingRetrievals && (
       <CostTooltip
         promptTokens={data.promptTokens}
         completionTokens={data.completionTokens}
@@ -80,8 +80,6 @@ export default function VariantStats(props: { variant: PromptVariant }) {
           <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
         </HStack>
       </CostTooltip>
-    ) : (
-      <Skeleton height={4} width={12} mr={1} />
     )}
   </HStack>
 );


@@ -14,12 +14,12 @@ import {
 } from "@chakra-ui/react";
 import { BsFillTrashFill, BsGear } from "react-icons/bs";
 import { FaRegClone } from "react-icons/fa";
-import { RiExchangeFundsFill } from "react-icons/ri";
 import { AiOutlineDiff } from "react-icons/ai";
 import { useState } from "react";
-import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
+import { RiExchangeFundsFill } from "react-icons/ri";
 import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
 import { type SupportedModel } from "~/server/types";
+import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
 
 export default function VariantHeaderMenuButton({
   variant,


@@ -7,7 +7,6 @@ import { OpenAIChatModel, type SupportedModel } from "~/server/types";
 import { constructPrompt } from "~/server/utils/constructPrompt";
 import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
 import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
 import { type PromptVariant } from "@prisma/client";
 import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
@@ -110,17 +109,14 @@ export const promptVariantsRouter = createTRPCRouter({
         },
       },
       _sum: {
+        cost: true,
         promptTokens: true,
         completionTokens: true,
       },
     });
 
     const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-    const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
     const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-    const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
-    const overallCost = overallPromptCost + overallCompletionCost;
 
     const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
       where: {
@@ -137,7 +133,7 @@ export const promptVariantsRouter = createTRPCRouter({
       evalResults,
       promptTokens,
       completionTokens,
-      overallCost,
+      overallCost: overallTokens._sum?.cost ?? 0,
       scenarioCount,
       outputCount,
       awaitingRetrievals,
@@ -302,9 +298,12 @@ export const promptVariantsRouter = createTRPCRouter({
       });
       await requireCanModifyExperiment(existing.experimentId, ctx);
 
+      const constructedPrompt = await constructPrompt({ constructFn: existing.constructFn }, null);
+
       const promptConstructionFn = await deriveNewConstructFn(
         existing,
-        existing.model as SupportedModel,
+        // @ts-expect-error TODO clean this up
+        constructedPrompt?.model as SupportedModel,
         input.instructions,
       );
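
With the stored column gone, the last hunk above has to recover the model by evaluating the prompt definition itself. Roughly, under an assumed return shape for constructPrompt (it isn't typed yet, hence the cast in the diff):

```ts
import { constructPrompt } from "~/server/utils/constructPrompt";
import { type SupportedModel } from "~/server/types";

// Sketch only: with PromptVariant.model dropped, the constructed prompt
// object is assumed to carry the `model` field, now the single source of truth.
async function modelForVariant(constructFn: string): Promise<SupportedModel | undefined> {
  const constructedPrompt = await constructPrompt({ constructFn }, null);
  return (constructedPrompt as { model?: SupportedModel } | null)?.model;
}
```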


@@ -0,0 +1,26 @@
+// /* eslint-disable */
+// import "dotenv/config";
+// import Replicate from "replicate";
+
+// const replicate = new Replicate({
+//   auth: process.env.REPLICATE_API_TOKEN || "",
+// });
+
+// console.log("going to run");
+// const prediction = await replicate.predictions.create({
+//   version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+//   input: {
+//     prompt: "...",
+//   },
+// });
+
+// console.log("waiting");
+
+// setInterval(() => {
+//   replicate.predictions.get(prediction.id).then((prediction) => {
+//     console.log(prediction.output);
+//   });
+// }, 500);
+
+// // const output = await replicate.wait(prediction, {});
+// // console.log(output);


@@ -1,7 +1,7 @@
 import crypto from "crypto";
 import { prisma } from "~/server/db";
 import defineTask from "./defineTask";
-import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
+import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
 import { type JSONSerializable } from "../types";
 import { sleep } from "../utils/sleep";
 import { shouldStream } from "../utils/shouldStream";
@@ -29,7 +29,10 @@ const getCompletionWithRetries = async (
   let modelResponse: CompletionResponse | null = null;
   try {
     for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
-      modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
+      modelResponse = await getOpenAIChatCompletion(
+        payload as unknown as CompletionCreateParams,
+        channel,
+      );
       if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
         return modelResponse;
       }
@@ -50,7 +53,7 @@ const getCompletionWithRetries = async (
     return {
       statusCode: modelResponse?.statusCode ?? 500,
      errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
-      output: null as unknown as Prisma.InputJsonValue,
+      output: null,
       timeToComplete: 0,
     };
   }
@@ -149,10 +152,11 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     data: {
       scenarioVariantCellId,
       inputHash,
-      output: modelResponse.output,
+      output: modelResponse.output as unknown as Prisma.InputJsonObject,
       timeToComplete: modelResponse.timeToComplete,
       promptTokens: modelResponse.promptTokens,
       completionTokens: modelResponse.completionTokens,
+      cost: modelResponse.cost,
     },
   });


@@ -62,6 +62,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
       inputHash,
       output: matchingModelOutput.output as Prisma.InputJsonValue,
       timeToComplete: matchingModelOutput.timeToComplete,
+      cost: matchingModelOutput.cost,
       promptTokens: matchingModelOutput.promptTokens,
       completionTokens: matchingModelOutput.completionTokens,
       createdAt: matchingModelOutput.createdAt,


@@ -1,30 +1,24 @@
 /* eslint-disable @typescript-eslint/no-unsafe-call */
 import { isObject } from "lodash-es";
-import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type OpenAIChatModel } from "../types";
+import { type SupportedModel, type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
 import { rateLimitErrorMessage } from "~/sharedStrings";
+import { modelStats } from "../modelStats";
 
 export type CompletionResponse = {
-  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
+  output: ChatCompletion | null;
   statusCode: number;
   errorMessage: string | null;
   timeToComplete: number;
   promptTokens?: number;
   completionTokens?: number;
+  cost?: number;
 };
 
-export async function getCompletion(
-  payload: CompletionCreateParams,
-  channel?: string,
-): Promise<CompletionResponse> {
-  return getOpenAIChatCompletion(payload, channel);
-}
-
 export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
   channel?: string,
@@ -42,7 +36,7 @@ export async function getOpenAIChatCompletion(
   });
 
   const resp: CompletionResponse = {
-    output: Prisma.JsonNull,
+    output: null,
     errorMessage: null,
     statusCode: response.status,
     timeToComplete: 0,
@@ -59,7 +53,7 @@ export async function getOpenAIChatCompletion(
       }
     })().catch((err) => console.error(err));
     if (finalOutput) {
-      resp.output = finalOutput as unknown as Prisma.InputJsonValue;
+      resp.output = finalOutput;
       resp.timeToComplete = Date.now() - start;
     }
   } else {
@@ -95,6 +89,13 @@ export async function getOpenAIChatCompletion(
       resp.completionTokens = countOpenAIChatTokens(model, messages);
     }
   }
+
+  const stats = modelStats[resp.output?.model as SupportedModel];
+  if (stats && resp.promptTokens && resp.completionTokens) {
+    resp.cost =
+      resp.promptTokens * stats.promptTokenPrice +
+      resp.completionTokens * stats.completionTokenPrice;
+  }
 } catch (e) {
   console.error(e);
   if (response.ok) {
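
For a concrete sense of the numbers in the new cost calculation (hypothetical per-token prices; the real values live in modelStats):

```ts
// Worked example with made-up prices: $0.0015 / 1K prompt tokens and
// $0.002 / 1K completion tokens, stored per-token as in modelStats.
const promptTokenPrice = 0.0000015;
const completionTokenPrice = 0.000002;

// An output with 1,000 prompt tokens and 500 completion tokens:
const cost = 1000 * promptTokenPrice + 500 * completionTokenPrice;
console.log(cost); // ≈ 0.0025, cached on ModelOutput.cost and rendered via cost.toFixed(3)
```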


@@ -1,25 +0,0 @@
import { modelStats } from "~/server/modelStats";
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
export const calculateTokenCost = (
model: SupportedModel | string | null,
numTokens: number,
isCompletion = false,
) => {
if (!model) return 0;
if (model in OpenAIChatModel) {
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
}
return 0;
};
const calculateOpenAIChatTokenCost = (
model: OpenAIChatModel,
numTokens: number,
isCompletion: boolean,
) => {
const tokensToDollars = isCompletion
? modelStats[model].completionTokenPrice
: modelStats[model].promptTokenPrice;
return tokensToDollars * numTokens;
};