diff --git a/package.json b/package.json index cd5402d..d4f6769 100644 --- a/package.json +++ b/package.json @@ -43,6 +43,7 @@ "next-auth": "^4.22.1", "nextjs-routes": "^2.0.1", "openai": "4.0.0-beta.2", + "pluralize": "^8.0.0", "posthog-js": "^1.68.4", "react": "18.2.0", "react-dom": "18.2.0", @@ -64,6 +65,7 @@ "@types/express": "^4.17.17", "@types/lodash": "^4.14.195", "@types/node": "^18.16.0", + "@types/pluralize": "^0.0.30", "@types/react": "^18.2.6", "@types/react-dom": "^18.2.4", "@types/react-syntax-highlighter": "^15.5.7", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 41da3d7..d9d47a6 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.1' +lockfileVersion: '6.0' settings: autoInstallPeers: true @@ -92,6 +92,9 @@ dependencies: openai: specifier: 4.0.0-beta.2 version: 4.0.0-beta.2 + pluralize: + specifier: ^8.0.0 + version: 8.0.0 posthog-js: specifier: ^1.68.4 version: 1.68.4 @@ -151,6 +154,9 @@ devDependencies: '@types/node': specifier: ^18.16.0 version: 18.16.0 + '@types/pluralize': + specifier: ^0.0.30 + version: 0.0.30 '@types/react': specifier: ^18.2.6 version: 18.2.6 @@ -2179,6 +2185,10 @@ packages: resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==} dev: false + /@types/pluralize@0.0.30: + resolution: {integrity: sha512-kVww6xZrW/db5BR9OqiT71J9huRdQ+z/r+LbDuT7/EK50mCmj5FoaIARnVv0rvjUS/YpDox0cDU9lpQT011VBA==} + dev: true + /@types/prop-types@15.7.5: resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==} @@ -4883,6 +4893,11 @@ packages: resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==} engines: {node: '>=8.6'} + /pluralize@8.0.0: + resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} + engines: {node: '>=4'} + dev: false + /postcss@8.4.14: resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==} engines: {node: ^10 || ^12 || >=14} diff --git a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx b/src/components/OutputsTable/OutputCell/ErrorHandler.tsx new file mode 100644 index 0000000..0e92a1a --- /dev/null +++ b/src/components/OutputsTable/OutputCell/ErrorHandler.tsx @@ -0,0 +1,102 @@ +import { type ModelOutput } from "@prisma/client"; +import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react"; +import { useEffect, useMemo, useState } from "react"; +import { BsArrowClockwise } from "react-icons/bs"; +import { rateLimitErrorMessage } from "~/sharedStrings"; +import pluralize from 'pluralize' + +const MAX_AUTO_RETRIES = 3; + +export const ErrorHandler = ({ + output, + refetchOutput, + numPreviousTries, +}: { + output: ModelOutput; + refetchOutput: () => void; + numPreviousTries: number; +}) => { + const [msToWait, setMsToWait] = useState(0); + const shouldAutoRetry = + output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES; + + const errorMessage = useMemo(() => breakLongWords(output.errorMessage), [output.errorMessage]); + + useEffect(() => { + if (!shouldAutoRetry) return; + + const initialWaitTime = calculateDelay(numPreviousTries); + const msModuloOneSecond = initialWaitTime % 1000; + let remainingTime = initialWaitTime - msModuloOneSecond; + setMsToWait(remainingTime); + + let interval: NodeJS.Timeout; + const timeout = setTimeout(() => { + interval = setInterval(() => { + remainingTime -= 1000; + setMsToWait(remainingTime); + + if (remainingTime <= 0) { + refetchOutput(); + clearInterval(interval); + } + }, 1000); + }, msModuloOneSecond); + + return () => { + clearInterval(interval); + clearTimeout(timeout); + }; + }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]); + + return ( + + + + Error + + + + {errorMessage} + {msToWait > 0 && ( + + Retrying in {pluralize('second', Math.ceil(msToWait / 1000), true)}... + + )} + + ); +}; + +function breakLongWords(str: string | null): string { + if (!str) return ""; + const words = str.split(" "); + + const newWords = words.map((word) => { + return word.length > 20 ? word.slice(0, 20) + "\u200B" + word.slice(20) : word; + }); + + return newWords.join(" "); +} + +const MIN_DELAY = 500; // milliseconds +const MAX_DELAY = 5000; // milliseconds + +function calculateDelay(numPreviousTries: number): number { + const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries)); + const jitter = Math.random() * baseDelay; + return baseDelay + jitter; +} diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx similarity index 57% rename from src/components/OutputsTable/OutputCell.tsx rename to src/components/OutputsTable/OutputCell/OutputCell.tsx index de6dcdb..012a6e9 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -1,21 +1,25 @@ import { type RouterOutputs, api } from "~/utils/api"; -import { type PromptVariant, type Scenario } from "./types"; -import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react"; +import { type PromptVariant, type Scenario } from "../types"; +import { + Spinner, + Text, + Box, + Center, + Flex, +} from "@chakra-ui/react"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; -import { type ReactElement, useState, useEffect, useRef } from "react"; -import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs"; -import { type ModelOutput } from "@prisma/client"; +import { type ReactElement, useState, useEffect, useRef, useCallback } from "react"; import { type ChatCompletion } from "openai/resources/chat"; import { generateChannel } from "~/utils/generateChannel"; import { isObject } from "lodash"; import useSocket from "~/utils/useSocket"; -import { evaluateOutput } from "~/server/utils/evaluateOutput"; -import { calculateTokenCost } from "~/utils/calculateTokenCost"; -import { type JSONSerializable, type SupportedModel } from "~/server/types"; +import { type JSONSerializable } from "~/server/types"; import { getModelName } from "~/server/utils/getModelName"; +import { OutputStats } from "./OutputStats"; +import { ErrorHandler } from "./ErrorHandler"; export default function OutputCell({ scenario, @@ -47,31 +51,40 @@ export default function OutputCell({ const [output, setOutput] = useState(null); const [channel, setChannel] = useState(undefined); + const [numPreviousTries, setNumPreviousTries] = useState(0); + const fetchMutex = useRef(false); - const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => { - if (fetchMutex.current) return; - fetchMutex.current = true; - setOutput(null); + const [fetchOutput, fetchingOutput] = useHandledAsyncCallback( + async (forceRefetch?: boolean) => { + if (fetchMutex.current) return; + setNumPreviousTries((prev) => prev + 1); - const shouldStream = - isObject(variant) && - "config" in variant && - isObject(variant.config) && - "stream" in variant.config && - variant.config.stream === true; + fetchMutex.current = true; + setOutput(null); - const channel = shouldStream ? generateChannel() : undefined; - setChannel(channel); + const shouldStream = + isObject(variant) && + "config" in variant && + isObject(variant.config) && + "stream" in variant.config && + variant.config.stream === true; - const output = await outputMutation.mutateAsync({ - scenarioId: scenario.id, - variantId: variant.id, - channel, - }); - setOutput(output); - await utils.promptVariants.stats.invalidate(); - fetchMutex.current = false; - }, [outputMutation, scenario.id, variant.id]); + const channel = shouldStream ? generateChannel() : undefined; + setChannel(channel); + + const output = await outputMutation.mutateAsync({ + scenarioId: scenario.id, + variantId: variant.id, + channel, + forceRefetch, + }); + setOutput(output); + await utils.promptVariants.stats.invalidate(); + fetchMutex.current = false; + }, + [outputMutation, scenario.id, variant.id] + ); + const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]); useEffect(fetchOutput, [scenario.id, variant.id]); @@ -93,7 +106,13 @@ export default function OutputCell({ if (!output && !fetchingOutput) return Error retrieving output; if (output && output.errorMessage) { - return Error: {output.errorMessage}; + return ( + + ); } const response = output?.output as unknown as ChatCompletion; @@ -142,54 +161,4 @@ export default function OutputCell({ ); } -const OutputStats = ({ - model, - modelOutput, - scenario, -}: { - model: SupportedModel | null; - modelOutput: ModelOutput; - scenario: Scenario; -}) => { - const timeToComplete = modelOutput.timeToComplete; - const experiment = useExperiment(); - const evals = - api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; - const promptTokens = modelOutput.promptTokens; - const completionTokens = modelOutput.completionTokens; - - const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0; - const completionCost = - completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0; - - const cost = promptCost + completionCost; - - return ( - - - {evals.map((evaluation) => { - const passed = evaluateOutput(modelOutput, scenario, evaluation); - return ( - - {evaluation.name} - - - ); - })} - - - - {cost.toFixed(3)} - - - - {(timeToComplete / 1000).toFixed(2)}s - - - ); -}; diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx new file mode 100644 index 0000000..59bc17e --- /dev/null +++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx @@ -0,0 +1,61 @@ +import { type ModelOutput } from "@prisma/client"; +import { type SupportedModel } from "~/server/types"; +import { type Scenario } from "../types"; +import { useExperiment } from "~/utils/hooks"; +import { api } from "~/utils/api"; +import { calculateTokenCost } from "~/utils/calculateTokenCost"; +import { evaluateOutput } from "~/server/utils/evaluateOutput"; +import { HStack, Icon, Text } from "@chakra-ui/react"; +import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; + +export const OutputStats = ({ + model, + modelOutput, + scenario, +}: { + model: SupportedModel | null; + modelOutput: ModelOutput; + scenario: Scenario; +}) => { + const timeToComplete = modelOutput.timeToComplete; + const experiment = useExperiment(); + const evals = + api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; + + const promptTokens = modelOutput.promptTokens; + const completionTokens = modelOutput.completionTokens; + + const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0; + const completionCost = + completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0; + + const cost = promptCost + completionCost; + + return ( + + + {evals.map((evaluation) => { + const passed = evaluateOutput(modelOutput, scenario, evaluation); + return ( + + {evaluation.name} + + + ); + })} + + + + {cost.toFixed(3)} + + + + {(timeToComplete / 1000).toFixed(2)}s + + + ); +}; diff --git a/src/components/OutputsTable/ScenarioRow.tsx b/src/components/OutputsTable/ScenarioRow.tsx index d12e1dc..da09867 100644 --- a/src/components/OutputsTable/ScenarioRow.tsx +++ b/src/components/OutputsTable/ScenarioRow.tsx @@ -1,7 +1,7 @@ import { Box, GridItem } from "@chakra-ui/react"; import React, { useState } from "react"; import { cellPadding } from "../constants"; -import OutputCell from "./OutputCell"; +import OutputCell from "./OutputCell/OutputCell"; import ScenarioEditor from "./ScenarioEditor"; import type { PromptVariant, Scenario } from "./types"; diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index 12d0192..5f5d1fc 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -11,7 +11,12 @@ import { getCompletion } from "~/server/utils/getCompletion"; export const modelOutputsRouter = createTRPCRouter({ get: publicProcedure .input( - z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }) + z.object({ + scenarioId: z.string(), + variantId: z.string(), + channel: z.string().optional(), + forceRefetch: z.boolean().optional(), + }) ) .mutation(async ({ input }) => { const existing = await prisma.modelOutput.findUnique({ @@ -23,7 +28,7 @@ export const modelOutputsRouter = createTRPCRouter({ }, }); - if (existing) return existing; + if (existing && !input.forceRefetch) return existing; const variant = await prisma.promptVariant.findUnique({ where: { @@ -69,13 +74,22 @@ export const modelOutputsRouter = createTRPCRouter({ modelResponse = await getCompletion(filledTemplate, input.channel); } - const modelOutput = await prisma.modelOutput.create({ - data: { + const modelOutput = await prisma.modelOutput.upsert({ + where: { + promptVariantId_testScenarioId: { + promptVariantId: input.variantId, + testScenarioId: input.scenarioId, + } + }, + create: { promptVariantId: input.variantId, testScenarioId: input.scenarioId, inputHash, ...modelResponse, }, + update: { + ...modelResponse, + }, }); await reevaluateVariant(input.variantId); diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts index 80cfa8c..2419eed 100644 --- a/src/server/utils/getCompletion.ts +++ b/src/server/utils/getCompletion.ts @@ -8,6 +8,7 @@ import { type JSONSerializable, OpenAIChatModel } from "../types"; import { env } from "~/env.mjs"; import { countOpenAIChatTokens } from "~/utils/countTokens"; import { getModelName } from "./getModelName"; +import { rateLimitErrorMessage } from "~/sharedStrings"; env; @@ -32,9 +33,7 @@ export async function getCompletion( errorMessage: "Invalid payload provided", timeToComplete: 0, }; - if ( - modelName in OpenAIChatModel - ) { + if (modelName in OpenAIChatModel) { return getOpenAIChatCompletion( payload as unknown as CompletionCreateParams, env.OPENAI_API_KEY, @@ -93,13 +92,15 @@ export async function getOpenAIChatCompletion( } if (!response.ok) { - // If it's an object, try to get the error message - if ( + if (response.status === 429) { + resp.errorMessage = rateLimitErrorMessage; + } else if ( isObject(resp.output) && "error" in resp.output && isObject(resp.output.error) && "message" in resp.output.error ) { + // If it's an object, try to get the error message resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error"; } } @@ -108,16 +109,13 @@ export async function getOpenAIChatCompletion( const usage = resp.output.usage as unknown as ChatCompletion.Usage; resp.promptTokens = usage.prompt_tokens; resp.completionTokens = usage.completion_tokens; - } else if (isObject(resp.output) && 'choices' in resp.output) { - const model = payload.model as unknown as OpenAIChatModel - resp.promptTokens = countOpenAIChatTokens( - model, - payload.messages - ); + } else if (isObject(resp.output) && "choices" in resp.output) { + const model = payload.model as unknown as OpenAIChatModel; + resp.promptTokens = countOpenAIChatTokens(model, payload.messages); const choices = resp.output.choices as unknown as ChatCompletion.Choice[]; - const message = choices[0]?.message + const message = choices[0]?.message; if (message) { - const messages = [message] + const messages = [message]; resp.completionTokens = countOpenAIChatTokens(model, messages); } } diff --git a/src/sharedStrings.ts b/src/sharedStrings.ts new file mode 100644 index 0000000..65214de --- /dev/null +++ b/src/sharedStrings.ts @@ -0,0 +1 @@ +export const rateLimitErrorMessage = "429 - Rate limit exceeded.";