Retry requests that receive 429 (#15)
* List number of scenarios
* Retry requests after 429
* Rename requestCallback
* Add sleep function
* Allow manual retry on frontend
* Remove unused utility functions
* Auto refetch
* Display wait time with Math.ceil
* Take one-second modulo into account
* Add pluralize
This commit is contained in:
@@ -43,6 +43,7 @@
|
||||
"next-auth": "^4.22.1",
|
||||
"nextjs-routes": "^2.0.1",
|
||||
"openai": "4.0.0-beta.2",
|
||||
"pluralize": "^8.0.0",
|
||||
"posthog-js": "^1.68.4",
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
@@ -64,6 +65,7 @@
|
||||
"@types/express": "^4.17.17",
|
||||
"@types/lodash": "^4.14.195",
|
||||
"@types/node": "^18.16.0",
|
||||
"@types/pluralize": "^0.0.30",
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
"@types/react-syntax-highlighter": "^15.5.7",
|
||||
|
||||
17
pnpm-lock.yaml
generated
17
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
||||
lockfileVersion: '6.1'
|
||||
lockfileVersion: '6.0'
|
||||
|
||||
settings:
|
||||
autoInstallPeers: true
|
||||
@@ -92,6 +92,9 @@ dependencies:
|
||||
openai:
|
||||
specifier: 4.0.0-beta.2
|
||||
version: 4.0.0-beta.2
|
||||
pluralize:
|
||||
specifier: ^8.0.0
|
||||
version: 8.0.0
|
||||
posthog-js:
|
||||
specifier: ^1.68.4
|
||||
version: 1.68.4
|
||||
@@ -151,6 +154,9 @@ devDependencies:
|
||||
'@types/node':
|
||||
specifier: ^18.16.0
|
||||
version: 18.16.0
|
||||
'@types/pluralize':
|
||||
specifier: ^0.0.30
|
||||
version: 0.0.30
|
||||
'@types/react':
|
||||
specifier: ^18.2.6
|
||||
version: 18.2.6
|
||||
@@ -2179,6 +2185,10 @@ packages:
|
||||
resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==}
|
||||
dev: false
|
||||
|
||||
/@types/pluralize@0.0.30:
|
||||
resolution: {integrity: sha512-kVww6xZrW/db5BR9OqiT71J9huRdQ+z/r+LbDuT7/EK50mCmj5FoaIARnVv0rvjUS/YpDox0cDU9lpQT011VBA==}
|
||||
dev: true
|
||||
|
||||
/@types/prop-types@15.7.5:
|
||||
resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==}
|
||||
|
||||
@@ -4883,6 +4893,11 @@ packages:
|
||||
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
|
||||
engines: {node: '>=8.6'}
|
||||
|
||||
/pluralize@8.0.0:
|
||||
resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
|
||||
engines: {node: '>=4'}
|
||||
dev: false
|
||||
|
||||
/postcss@8.4.14:
|
||||
resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==}
|
||||
engines: {node: ^10 || ^12 || >=14}
|
||||
|
||||
102
src/components/OutputsTable/OutputCell/ErrorHandler.tsx
Normal file
102
src/components/OutputsTable/OutputCell/ErrorHandler.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
import pluralize from 'pluralize'
|
||||
|
||||
const MAX_AUTO_RETRIES = 3;
|
||||
|
||||
export const ErrorHandler = ({
|
||||
output,
|
||||
refetchOutput,
|
||||
numPreviousTries,
|
||||
}: {
|
||||
output: ModelOutput;
|
||||
refetchOutput: () => void;
|
||||
numPreviousTries: number;
|
||||
}) => {
|
||||
const [msToWait, setMsToWait] = useState(0);
|
||||
const shouldAutoRetry =
|
||||
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
|
||||
|
||||
const errorMessage = useMemo(() => breakLongWords(output.errorMessage), [output.errorMessage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!shouldAutoRetry) return;
|
||||
|
||||
const initialWaitTime = calculateDelay(numPreviousTries);
|
||||
const msModuloOneSecond = initialWaitTime % 1000;
|
||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
let interval: NodeJS.Timeout;
|
||||
const timeout = setTimeout(() => {
|
||||
interval = setInterval(() => {
|
||||
remainingTime -= 1000;
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
if (remainingTime <= 0) {
|
||||
refetchOutput();
|
||||
clearInterval(interval);
|
||||
}
|
||||
}, 1000);
|
||||
}, msModuloOneSecond);
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
|
||||
|
||||
return (
|
||||
<VStack w="full">
|
||||
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
|
||||
<Text color="red.600" fontWeight="bold">
|
||||
Error
|
||||
</Text>
|
||||
<Button
|
||||
size="xs"
|
||||
w={4}
|
||||
h={4}
|
||||
px={4}
|
||||
py={4}
|
||||
minW={0}
|
||||
borderRadius={8}
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={6} />
|
||||
</Button>
|
||||
</HStack>
|
||||
<Text color="red.600">{errorMessage}</Text>
|
||||
{msToWait > 0 && (
|
||||
<Text color="red.600" fontSize="sm">
|
||||
Retrying in {pluralize('second', Math.ceil(msToWait / 1000), true)}...
|
||||
</Text>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
function breakLongWords(str: string | null): string {
|
||||
if (!str) return "";
|
||||
const words = str.split(" ");
|
||||
|
||||
const newWords = words.map((word) => {
|
||||
return word.length > 20 ? word.slice(0, 20) + "\u200B" + word.slice(20) : word;
|
||||
});
|
||||
|
||||
return newWords.join(" ");
|
||||
}
|
||||
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 5000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
@@ -1,21 +1,25 @@
|
||||
import { type RouterOutputs, api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "./types";
|
||||
import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import {
|
||||
Spinner,
|
||||
Text,
|
||||
Box,
|
||||
Center,
|
||||
Flex,
|
||||
} from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement, useState, useEffect, useRef } from "react";
|
||||
import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { isObject } from "lodash";
|
||||
import useSocket from "~/utils/useSocket";
|
||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { type JSONSerializable, type SupportedModel } from "~/server/types";
|
||||
import { type JSONSerializable } from "~/server/types";
|
||||
import { getModelName } from "~/server/utils/getModelName";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -47,9 +51,14 @@ export default function OutputCell({
|
||||
|
||||
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
|
||||
const [channel, setChannel] = useState<string | undefined>(undefined);
|
||||
const [numPreviousTries, setNumPreviousTries] = useState(0);
|
||||
|
||||
const fetchMutex = useRef(false);
|
||||
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => {
|
||||
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
|
||||
async (forceRefetch?: boolean) => {
|
||||
if (fetchMutex.current) return;
|
||||
setNumPreviousTries((prev) => prev + 1);
|
||||
|
||||
fetchMutex.current = true;
|
||||
setOutput(null);
|
||||
|
||||
@@ -67,11 +76,15 @@ export default function OutputCell({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
channel,
|
||||
forceRefetch,
|
||||
});
|
||||
setOutput(output);
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
fetchMutex.current = false;
|
||||
}, [outputMutation, scenario.id, variant.id]);
|
||||
},
|
||||
[outputMutation, scenario.id, variant.id]
|
||||
);
|
||||
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
|
||||
|
||||
useEffect(fetchOutput, [scenario.id, variant.id]);
|
||||
|
||||
@@ -93,7 +106,13 @@ export default function OutputCell({
|
||||
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
|
||||
if (output && output.errorMessage) {
|
||||
return <Text color="red.600">Error: {output.errorMessage}</Text>;
|
||||
return (
|
||||
<ErrorHandler
|
||||
output={output}
|
||||
refetchOutput={hardRefetch}
|
||||
numPreviousTries={numPreviousTries}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const response = output?.output as unknown as ChatCompletion;
|
||||
@@ -142,54 +161,4 @@ export default function OutputCell({
|
||||
);
|
||||
}
|
||||
|
||||
const OutputStats = ({
|
||||
model,
|
||||
modelOutput,
|
||||
scenario,
|
||||
}: {
|
||||
model: SupportedModel | null;
|
||||
modelOutput: ModelOutput;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete = modelOutput.timeToComplete;
|
||||
const experiment = useExperiment();
|
||||
const evals =
|
||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const promptTokens = modelOutput.promptTokens;
|
||||
const completionTokens = modelOutput.completionTokens;
|
||||
|
||||
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
|
||||
const completionCost =
|
||||
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
|
||||
|
||||
const cost = promptCost + completionCost;
|
||||
|
||||
return (
|
||||
<HStack align="center" color="gray.500" fontSize="xs" mt={2}>
|
||||
<HStack flex={1}>
|
||||
{evals.map((evaluation) => {
|
||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||
return (
|
||||
<HStack spacing={0} key={evaluation.id}>
|
||||
<Text>{evaluation.name}</Text>
|
||||
<Icon
|
||||
as={passed ? BsCheck : BsX}
|
||||
color={passed ? "green.500" : "red.500"}
|
||||
boxSize={6}
|
||||
/>
|
||||
</HStack>
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
<HStack spacing={0}>
|
||||
<Icon as={BsCurrencyDollar} />
|
||||
<Text mr={1}>{cost.toFixed(3)}</Text>
|
||||
</HStack>
|
||||
<HStack spacing={0.5}>
|
||||
<Icon as={BsClock} />
|
||||
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
|
||||
</HStack>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
61
src/components/OutputsTable/OutputCell/OutputStats.tsx
Normal file
61
src/components/OutputsTable/OutputCell/OutputStats.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { type Scenario } from "../types";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||
import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
|
||||
/**
 * Footer row of stats for a model output cell: per-evaluation pass/fail
 * icons, estimated token cost in dollars, and completion time in seconds.
 * Cost is 0 when the model is unknown or token counts are missing.
 */
export const OutputStats = ({
  model,
  modelOutput,
  scenario,
}: {
  model: SupportedModel | null;
  modelOutput: ModelOutput;
  scenario: Scenario;
}) => {
  const experiment = useExperiment();
  const evals =
    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];

  const { timeToComplete, promptTokens, completionTokens } = modelOutput;

  // calculateTokenCost's third argument selects completion-token pricing.
  const inputCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
  const outputCost =
    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
  const cost = inputCost + outputCost;

  return (
    <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
      <HStack flex={1}>
        {evals.map((evaluation) => {
          const passed = evaluateOutput(modelOutput, scenario, evaluation);
          return (
            <HStack spacing={0} key={evaluation.id}>
              <Text>{evaluation.name}</Text>
              <Icon
                as={passed ? BsCheck : BsX}
                color={passed ? "green.500" : "red.500"}
                boxSize={6}
              />
            </HStack>
          );
        })}
      </HStack>
      <HStack spacing={0}>
        <Icon as={BsCurrencyDollar} />
        <Text mr={1}>{cost.toFixed(3)}</Text>
      </HStack>
      <HStack spacing={0.5}>
        <Icon as={BsClock} />
        <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
      </HStack>
    </HStack>
  );
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Box, GridItem } from "@chakra-ui/react";
|
||||
import React, { useState } from "react";
|
||||
import { cellPadding } from "../constants";
|
||||
import OutputCell from "./OutputCell";
|
||||
import OutputCell from "./OutputCell/OutputCell";
|
||||
import ScenarioEditor from "./ScenarioEditor";
|
||||
import type { PromptVariant, Scenario } from "./types";
|
||||
|
||||
|
||||
@@ -11,7 +11,12 @@ import { getCompletion } from "~/server/utils/getCompletion";
|
||||
export const modelOutputsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
channel: z.string().optional(),
|
||||
forceRefetch: z.boolean().optional(),
|
||||
})
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const existing = await prisma.modelOutput.findUnique({
|
||||
@@ -23,7 +28,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
if (existing) return existing;
|
||||
if (existing && !input.forceRefetch) return existing;
|
||||
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
@@ -69,13 +74,22 @@ export const modelOutputsRouter = createTRPCRouter({
|
||||
modelResponse = await getCompletion(filledTemplate, input.channel);
|
||||
}
|
||||
|
||||
const modelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
const modelOutput = await prisma.modelOutput.upsert({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
}
|
||||
},
|
||||
create: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
inputHash,
|
||||
...modelResponse,
|
||||
},
|
||||
update: {
|
||||
...modelResponse,
|
||||
},
|
||||
});
|
||||
|
||||
await reevaluateVariant(input.variantId);
|
||||
|
||||
@@ -8,6 +8,7 @@ import { type JSONSerializable, OpenAIChatModel } from "../types";
|
||||
import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { getModelName } from "./getModelName";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
|
||||
env;
|
||||
|
||||
@@ -32,9 +33,7 @@ export async function getCompletion(
|
||||
errorMessage: "Invalid payload provided",
|
||||
timeToComplete: 0,
|
||||
};
|
||||
if (
|
||||
modelName in OpenAIChatModel
|
||||
) {
|
||||
if (modelName in OpenAIChatModel) {
|
||||
return getOpenAIChatCompletion(
|
||||
payload as unknown as CompletionCreateParams,
|
||||
env.OPENAI_API_KEY,
|
||||
@@ -93,13 +92,15 @@ export async function getOpenAIChatCompletion(
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
// If it's an object, try to get the error message
|
||||
if (
|
||||
if (response.status === 429) {
|
||||
resp.errorMessage = rateLimitErrorMessage;
|
||||
} else if (
|
||||
isObject(resp.output) &&
|
||||
"error" in resp.output &&
|
||||
isObject(resp.output.error) &&
|
||||
"message" in resp.output.error
|
||||
) {
|
||||
// If it's an object, try to get the error message
|
||||
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
|
||||
}
|
||||
}
|
||||
@@ -108,16 +109,13 @@ export async function getOpenAIChatCompletion(
|
||||
const usage = resp.output.usage as unknown as ChatCompletion.Usage;
|
||||
resp.promptTokens = usage.prompt_tokens;
|
||||
resp.completionTokens = usage.completion_tokens;
|
||||
} else if (isObject(resp.output) && 'choices' in resp.output) {
|
||||
const model = payload.model as unknown as OpenAIChatModel
|
||||
resp.promptTokens = countOpenAIChatTokens(
|
||||
model,
|
||||
payload.messages
|
||||
);
|
||||
} else if (isObject(resp.output) && "choices" in resp.output) {
|
||||
const model = payload.model as unknown as OpenAIChatModel;
|
||||
resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
|
||||
const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
|
||||
const message = choices[0]?.message
|
||||
const message = choices[0]?.message;
|
||||
if (message) {
|
||||
const messages = [message]
|
||||
const messages = [message];
|
||||
resp.completionTokens = countOpenAIChatTokens(model, messages);
|
||||
}
|
||||
}
|
||||
|
||||
1
src/sharedStrings.ts
Normal file
1
src/sharedStrings.ts
Normal file
@@ -0,0 +1 @@
|
||||
// Sentinel error message shared between the server (which sets it on 429
// responses in getCompletion) and the frontend ErrorHandler (which compares
// against it to decide whether to auto-retry). Must match exactly on both
// sides — do not edit one without the other.
export const rateLimitErrorMessage = "429 - Rate limit exceeded.";
|
||||
Reference in New Issue
Block a user