Remove model from promptVariant and add cost
Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- we might have to add it back later if this causes trouble. Also added `cost` to modelOutput so we can cache it there, which matters because cost calculations differ between API providers.
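For context, here's a minimal sketch (not part of the commit; the price table, its numbers, and the helper name are illustrative) of the flow this enables: each completion is priced once when it is retrieved and the result is cached on ModelOutput.cost, so per-variant totals become a plain aggregate instead of a per-provider recomputation at read time.

// Sketch only, in TypeScript. Mirrors the getCompletion.ts and promptVariants router changes below.
type ModelStats = { promptTokenPrice: number; completionTokenPrice: number };

// Illustrative per-token prices; the real table lives in modelStats and varies by provider/model.
const modelStats: Record<string, ModelStats> = {
  "gpt-3.5-turbo-0613": { promptTokenPrice: 0.0000015, completionTokenPrice: 0.000002 },
};

type CompletionUsage = { model: string; promptTokens?: number; completionTokens?: number };

// Price the completion once, when it comes back from the provider,
// so the result can be cached on ModelOutput.cost.
function priceCompletion(usage: CompletionUsage): number | undefined {
  const stats = modelStats[usage.model];
  if (!stats || !usage.promptTokens || !usage.completionTokens) return undefined;
  return (
    usage.promptTokens * stats.promptTokenPrice +
    usage.completionTokens * stats.completionTokenPrice
  );
}

// A variant's total cost is then just a sum over the cached values, e.g.
//   prisma.modelOutput.aggregate({ _sum: { cost: true }, where: { ... } })
// rather than re-deriving it from token counts and a price table on every read.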
pnpm-lock.yaml (generated, 2 lines changed)
@@ -1,4 +1,4 @@
-lockfileVersion: '6.0'
+lockfileVersion: '6.1'

 settings:
   autoInstallPeers: true
@@ -0,0 +1,8 @@
+/*
+  Warnings:
+
+  - You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
+
+*/
+-- AlterTable
+ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;
@@ -94,7 +94,7 @@ model ScenarioVariantCell {
   streamingChannel String?
   retrievalStatus  CellRetrievalStatus @default(COMPLETE)

-  modelOutput ModelOutput?
+  modelOutput ModelOutput?

   promptVariantId String        @db.Uuid
   promptVariant   PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -114,6 +114,7 @@ model ModelOutput {
   inputHash        String
   output           Json
   timeToComplete   Int    @default(0)
+  cost             Float?
   promptTokens     Int?
   completionTokens Int?
prisma/seedDemo.ts (1126 lines changed)
File diff suppressed because one or more lines are too long
@@ -129,7 +129,7 @@ export default function OutputCell({
          )}
        </SyntaxHighlighter>
      </VStack>
-      <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
+      <OutputStats modelOutput={modelOutput} scenario={scenario} />
    </VStack>
  );
}
@@ -143,9 +143,7 @@ export default function OutputCell({
        <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
        <Text>{contentToDisplay}</Text>
      </VStack>
-      {modelOutput && (
-        <OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
-      )}
+      {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
    </VStack>
  );
}
@@ -1,19 +1,14 @@
-import { type SupportedModel } from "~/server/types";
 import { type Scenario } from "../types";
 import { type RouterOutputs } from "~/utils/api";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
 import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
 import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
 import { CostTooltip } from "~/components/tooltip/CostTooltip";

 const SHOW_COST = true;
 const SHOW_TIME = true;

 export const OutputStats = ({
-  model,
   modelOutput,
 }: {
-  model: SupportedModel | string | null;
   modelOutput: NonNullable<
     NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
   >;
@@ -24,12 +19,6 @@ export const OutputStats = ({
   const promptTokens = modelOutput.promptTokens;
   const completionTokens = modelOutput.completionTokens;

-  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
-  const completionCost =
-    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-
-  const cost = promptCost + completionCost;
-
   return (
     <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
       <HStack flex={1}>
@@ -53,11 +42,15 @@ export const OutputStats = ({
          );
        })}
      </HStack>
-      {SHOW_COST && (
-        <CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}>
+      {modelOutput.cost && (
+        <CostTooltip
+          promptTokens={promptTokens}
+          completionTokens={completionTokens}
+          cost={modelOutput.cost}
+        >
          <HStack spacing={0}>
            <Icon as={BsCurrencyDollar} />
-            <Text mr={1}>{cost.toFixed(3)}</Text>
+            <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
          </HStack>
        </CostTooltip>
      )}
@@ -1,4 +1,4 @@
-import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
+import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
 import { type PromptVariant } from "./types";
 import { cellPadding } from "../constants";
 import { api } from "~/utils/api";
@@ -69,7 +69,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
          );
        })}
      </HStack>
-      {data.overallCost && !data.awaitingRetrievals ? (
+      {data.overallCost && !data.awaitingRetrievals && (
        <CostTooltip
          promptTokens={data.promptTokens}
          completionTokens={data.completionTokens}
@@ -80,8 +80,6 @@ export default function VariantStats(props: { variant: PromptVariant }) {
            <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
          </HStack>
        </CostTooltip>
-      ) : (
-        <Skeleton height={4} width={12} mr={1} />
      )}
    </HStack>
  );
@@ -14,12 +14,12 @@ import {
 } from "@chakra-ui/react";
 import { BsFillTrashFill, BsGear } from "react-icons/bs";
 import { FaRegClone } from "react-icons/fa";
+import { RiExchangeFundsFill } from "react-icons/ri";
 import { AiOutlineDiff } from "react-icons/ai";
 import { useState } from "react";
-import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
-import { RiExchangeFundsFill } from "react-icons/ri";
 import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
-import { type SupportedModel } from "~/server/types";
+import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";

 export default function VariantHeaderMenuButton({
   variant,
@@ -7,7 +7,6 @@ import { OpenAIChatModel, type SupportedModel } from "~/server/types";
 import { constructPrompt } from "~/server/utils/constructPrompt";
 import userError from "~/server/utils/error";
 import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
 import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
 import { type PromptVariant } from "@prisma/client";
 import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
@@ -110,17 +109,14 @@ export const promptVariantsRouter = createTRPCRouter({
        },
      },
      _sum: {
+        cost: true,
        promptTokens: true,
        completionTokens: true,
      },
    });

    const promptTokens = overallTokens._sum?.promptTokens ?? 0;
-    const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
    const completionTokens = overallTokens._sum?.completionTokens ?? 0;
-    const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
-
-    const overallCost = overallPromptCost + overallCompletionCost;

    const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
      where: {
@@ -137,7 +133,7 @@ export const promptVariantsRouter = createTRPCRouter({
      evalResults,
      promptTokens,
      completionTokens,
-      overallCost,
+      overallCost: overallTokens._sum?.cost ?? 0,
      scenarioCount,
      outputCount,
      awaitingRetrievals,
@@ -302,9 +298,12 @@ export const promptVariantsRouter = createTRPCRouter({
    });
    await requireCanModifyExperiment(existing.experimentId, ctx);

+    const constructedPrompt = await constructPrompt({ constructFn: existing.constructFn }, null);
+
    const promptConstructionFn = await deriveNewConstructFn(
      existing,
-      existing.model as SupportedModel,
+      // @ts-expect-error TODO clean this up
+      constructedPrompt?.model as SupportedModel,
      input.instructions,
    );

src/server/scripts/replicate-test.ts (new file, 26 lines)
@@ -0,0 +1,26 @@
+// /* eslint-disable */
+
+// import "dotenv/config";
+// import Replicate from "replicate";
+
+// const replicate = new Replicate({
+//   auth: process.env.REPLICATE_API_TOKEN || "",
+// });
+
+// console.log("going to run");
+// const prediction = await replicate.predictions.create({
+//   version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
+//   input: {
+//     prompt: "...",
+//   },
+// });
+
+// console.log("waiting");
+// setInterval(() => {
+//   replicate.predictions.get(prediction.id).then((prediction) => {
+//     console.log(prediction.output);
+//   });
+// }, 500);
+// // const output = await replicate.wait(prediction, {});
+
+// // console.log(output);
@@ -1,7 +1,7 @@
 import crypto from "crypto";
 import { prisma } from "~/server/db";
 import defineTask from "./defineTask";
-import { type CompletionResponse, getCompletion } from "../utils/getCompletion";
+import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
 import { type JSONSerializable } from "../types";
 import { sleep } from "../utils/sleep";
 import { shouldStream } from "../utils/shouldStream";
@@ -29,7 +29,10 @@ const getCompletionWithRetries = async (
  let modelResponse: CompletionResponse | null = null;
  try {
    for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
-      modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel);
+      modelResponse = await getOpenAIChatCompletion(
+        payload as unknown as CompletionCreateParams,
+        channel,
+      );
      if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
        return modelResponse;
      }
@@ -50,7 +53,7 @@
  return {
    statusCode: modelResponse?.statusCode ?? 500,
    errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
-    output: null as unknown as Prisma.InputJsonValue,
+    output: null,
    timeToComplete: 0,
  };
}
@@ -149,10 +152,11 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
    data: {
      scenarioVariantCellId,
      inputHash,
-      output: modelResponse.output,
+      output: modelResponse.output as unknown as Prisma.InputJsonObject,
      timeToComplete: modelResponse.timeToComplete,
      promptTokens: modelResponse.promptTokens,
      completionTokens: modelResponse.completionTokens,
+      cost: modelResponse.cost,
    },
  });
}
@@ -62,6 +62,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
      inputHash,
      output: matchingModelOutput.output as Prisma.InputJsonValue,
      timeToComplete: matchingModelOutput.timeToComplete,
+      cost: matchingModelOutput.cost,
      promptTokens: matchingModelOutput.promptTokens,
      completionTokens: matchingModelOutput.completionTokens,
      createdAt: matchingModelOutput.createdAt,
@@ -1,30 +1,24 @@
 /* eslint-disable @typescript-eslint/no-unsafe-call */
 import { isObject } from "lodash-es";
-import { Prisma } from "@prisma/client";
 import { streamChatCompletion } from "./openai";
 import { wsConnection } from "~/utils/wsConnection";
 import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
-import { type OpenAIChatModel } from "../types";
+import { type SupportedModel, type OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
 import { rateLimitErrorMessage } from "~/sharedStrings";
+import { modelStats } from "../modelStats";

 export type CompletionResponse = {
-  output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
+  output: ChatCompletion | null;
   statusCode: number;
   errorMessage: string | null;
   timeToComplete: number;
   promptTokens?: number;
   completionTokens?: number;
+  cost?: number;
 };

-export async function getCompletion(
-  payload: CompletionCreateParams,
-  channel?: string,
-): Promise<CompletionResponse> {
-  return getOpenAIChatCompletion(payload, channel);
-}
-
 export async function getOpenAIChatCompletion(
   payload: CompletionCreateParams,
   channel?: string,
@@ -42,7 +36,7 @@ export async function getOpenAIChatCompletion(
  });

  const resp: CompletionResponse = {
-    output: Prisma.JsonNull,
+    output: null,
    errorMessage: null,
    statusCode: response.status,
    timeToComplete: 0,
@@ -59,7 +53,7 @@
      }
    })().catch((err) => console.error(err));
    if (finalOutput) {
-      resp.output = finalOutput as unknown as Prisma.InputJsonValue;
+      resp.output = finalOutput;
      resp.timeToComplete = Date.now() - start;
    }
  } else {
@@ -95,6 +89,13 @@
        resp.completionTokens = countOpenAIChatTokens(model, messages);
      }
    }
+
+    const stats = modelStats[resp.output?.model as SupportedModel];
+    if (stats && resp.promptTokens && resp.completionTokens) {
+      resp.cost =
+        resp.promptTokens * stats.promptTokenPrice +
+        resp.completionTokens * stats.completionTokenPrice;
+    }
  } catch (e) {
    console.error(e);
    if (response.ok) {
@@ -1,25 +0,0 @@
-import { modelStats } from "~/server/modelStats";
-import { type SupportedModel, OpenAIChatModel } from "~/server/types";
-
-export const calculateTokenCost = (
-  model: SupportedModel | string | null,
-  numTokens: number,
-  isCompletion = false,
-) => {
-  if (!model) return 0;
-  if (model in OpenAIChatModel) {
-    return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
-  }
-  return 0;
-};
-
-const calculateOpenAIChatTokenCost = (
-  model: OpenAIChatModel,
-  numTokens: number,
-  isCompletion: boolean,
-) => {
-  const tokensToDollars = isCompletion
-    ? modelStats[model].completionTokenPrice
-    : modelStats[model].promptTokenPrice;
-  return tokensToDollars * numTokens;
-};