Retry requests that receive 429 (#15)

* List number of scenarios

* Retry requests after 429

* Rename requestCallback

* Add sleep function

* Allow manual retry on frontend

* Remove unused utility functions

* Auto refetch

* Display wait time with Math.ceil

* Take one second modulo into account

* Add pluralize
This commit is contained in:
arcticfly
2023-07-06 21:39:23 -07:00
committed by GitHub
parent cb15216d0b
commit a2c7ef73ec
9 changed files with 261 additions and 99 deletions

View File

@@ -43,6 +43,7 @@
"next-auth": "^4.22.1", "next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1", "nextjs-routes": "^2.0.1",
"openai": "4.0.0-beta.2", "openai": "4.0.0-beta.2",
"pluralize": "^8.0.0",
"posthog-js": "^1.68.4", "posthog-js": "^1.68.4",
"react": "18.2.0", "react": "18.2.0",
"react-dom": "18.2.0", "react-dom": "18.2.0",
@@ -64,6 +65,7 @@
"@types/express": "^4.17.17", "@types/express": "^4.17.17",
"@types/lodash": "^4.14.195", "@types/lodash": "^4.14.195",
"@types/node": "^18.16.0", "@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30",
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4", "@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7", "@types/react-syntax-highlighter": "^15.5.7",

17
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.1' lockfileVersion: '6.0'
settings: settings:
autoInstallPeers: true autoInstallPeers: true
@@ -92,6 +92,9 @@ dependencies:
openai: openai:
specifier: 4.0.0-beta.2 specifier: 4.0.0-beta.2
version: 4.0.0-beta.2 version: 4.0.0-beta.2
pluralize:
specifier: ^8.0.0
version: 8.0.0
posthog-js: posthog-js:
specifier: ^1.68.4 specifier: ^1.68.4
version: 1.68.4 version: 1.68.4
@@ -151,6 +154,9 @@ devDependencies:
'@types/node': '@types/node':
specifier: ^18.16.0 specifier: ^18.16.0
version: 18.16.0 version: 18.16.0
'@types/pluralize':
specifier: ^0.0.30
version: 0.0.30
'@types/react': '@types/react':
specifier: ^18.2.6 specifier: ^18.2.6
version: 18.2.6 version: 18.2.6
@@ -2179,6 +2185,10 @@ packages:
resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==} resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==}
dev: false dev: false
/@types/pluralize@0.0.30:
resolution: {integrity: sha512-kVww6xZrW/db5BR9OqiT71J9huRdQ+z/r+LbDuT7/EK50mCmj5FoaIARnVv0rvjUS/YpDox0cDU9lpQT011VBA==}
dev: true
/@types/prop-types@15.7.5: /@types/prop-types@15.7.5:
resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==} resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==}
@@ -4883,6 +4893,11 @@ packages:
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==} resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
engines: {node: '>=8.6'} engines: {node: '>=8.6'}
/pluralize@8.0.0:
resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
engines: {node: '>=4'}
dev: false
/postcss@8.4.14: /postcss@8.4.14:
resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==} resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==}
engines: {node: ^10 || ^12 || >=14} engines: {node: ^10 || ^12 || >=14}

View File

@@ -0,0 +1,102 @@
import { type ModelOutput } from "@prisma/client";
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
import { useEffect, useMemo, useState } from "react";
import { BsArrowClockwise } from "react-icons/bs";
import { rateLimitErrorMessage } from "~/sharedStrings";
import pluralize from 'pluralize'
// Maximum number of automatic retries before the user must refetch manually.
const MAX_AUTO_RETRIES = 3;

/**
 * Error display for a model output.
 *
 * Always shows the error message plus a manual refetch button. When the error
 * is the shared rate-limit message (429) and fewer than MAX_AUTO_RETRIES
 * attempts have been made, it also schedules an automatic refetch with an
 * exponentially backed-off delay and renders a once-per-second countdown.
 */
export const ErrorHandler = ({
  output,
  refetchOutput,
  numPreviousTries,
}: {
  output: ModelOutput; // the failed output row, including errorMessage
  refetchOutput: () => void; // triggers a re-run of the request
  numPreviousTries: number; // attempts so far; feeds the backoff delay
}) => {
  // Milliseconds left before the scheduled auto-retry fires (0 = none pending).
  const [msToWait, setMsToWait] = useState(0);

  // Only auto-retry on the specific rate-limit error, and only up to the cap.
  const shouldAutoRetry =
    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;

  // Insert zero-width break opportunities so very long tokens can wrap.
  const errorMessage = useMemo(() => breakLongWords(output.errorMessage), [output.errorMessage]);

  useEffect(() => {
    if (!shouldAutoRetry) return;

    const initialWaitTime = calculateDelay(numPreviousTries);
    // Split the delay into a sub-second remainder plus whole seconds so the
    // visible countdown ticks on exact one-second boundaries.
    const msModuloOneSecond = initialWaitTime % 1000;
    let remainingTime = initialWaitTime - msModuloOneSecond;
    setMsToWait(remainingTime);

    let interval: NodeJS.Timeout;
    // First wait out the sub-second remainder, then tick once per second.
    const timeout = setTimeout(() => {
      interval = setInterval(() => {
        remainingTime -= 1000;
        setMsToWait(remainingTime);
        if (remainingTime <= 0) {
          // Countdown finished: trigger the retry and stop ticking.
          refetchOutput();
          clearInterval(interval);
        }
      }, 1000);
    }, msModuloOneSecond);

    // Cancel any pending timers if the inputs change or the component unmounts.
    return () => {
      clearInterval(interval);
      clearTimeout(timeout);
    };
  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);

  return (
    <VStack w="full">
      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
        <Text color="red.600" fontWeight="bold">
          Error
        </Text>
        {/* Manual retry button, available regardless of auto-retry state. */}
        <Button
          size="xs"
          w={4}
          h={4}
          px={4}
          py={4}
          minW={0}
          borderRadius={8}
          variant="ghost"
          cursor="pointer"
          onClick={refetchOutput}
          aria-label="refetch output"
        >
          <Icon as={BsArrowClockwise} boxSize={6} />
        </Button>
      </HStack>
      <Text color="red.600">{errorMessage}</Text>
      {msToWait > 0 && (
        <Text color="red.600" fontSize="sm">
          Retrying in {pluralize('second', Math.ceil(msToWait / 1000), true)}...
        </Text>
      )}
    </VStack>
  );
};
/**
 * Make an error message safely wrappable inside a fixed-width cell.
 *
 * Splits the message on spaces and inserts a zero-width space (U+200B) every
 * 20 characters inside any over-long word, giving the browser a break
 * opportunity without visibly altering the text.
 *
 * Fix: the previous version inserted only ONE break at index 20, so a word
 * longer than 40 characters still contained an unbreakable run; we now break
 * every 20 characters. Behavior for words of 40 chars or fewer is unchanged.
 *
 * @param str - raw error message, or null when none was recorded
 * @returns the message with break opportunities added, or "" for null/empty
 */
function breakLongWords(str: string | null): string {
  if (!str) return "";
  const words = str.split(" ");
  const newWords = words.map((word) => {
    if (word.length <= 20) return word;
    // Chunk into runs of at most 20 chars, then rejoin with zero-width spaces.
    return (word.match(/.{1,20}/g) ?? [word]).join("\u200B");
  });
  return newWords.join(" ");
}
// Exponential-backoff bounds for auto-retry, both in milliseconds.
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 5000; // milliseconds

/**
 * Compute the wait before the next automatic retry.
 *
 * The base delay doubles with each previous attempt, capped at MAX_DELAY,
 * and up to one extra base delay of random jitter is added on top to spread
 * concurrent retries apart.
 *
 * @param numPreviousTries - number of attempts already made
 * @returns delay in milliseconds (baseDelay <= result < 2 * baseDelay)
 */
function calculateDelay(numPreviousTries: number): number {
  const exponential = MIN_DELAY * 2 ** numPreviousTries;
  const baseDelay = Math.min(MAX_DELAY, exponential);
  return baseDelay + Math.random() * baseDelay;
}

View File

@@ -1,21 +1,25 @@
import { type RouterOutputs, api } from "~/utils/api"; import { type RouterOutputs, api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types"; import { type PromptVariant, type Scenario } from "../types";
import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react"; import {
Spinner,
Text,
Box,
Center,
Flex,
} from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
import { type ReactElement, useState, useEffect, useRef } from "react"; import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash"; import { isObject } from "lodash";
import useSocket from "~/utils/useSocket"; import useSocket from "~/utils/useSocket";
import { evaluateOutput } from "~/server/utils/evaluateOutput"; import { type JSONSerializable } from "~/server/types";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { type JSONSerializable, type SupportedModel } from "~/server/types";
import { getModelName } from "~/server/utils/getModelName"; import { getModelName } from "~/server/utils/getModelName";
import { OutputStats } from "./OutputStats";
import { ErrorHandler } from "./ErrorHandler";
export default function OutputCell({ export default function OutputCell({
scenario, scenario,
@@ -47,31 +51,40 @@ export default function OutputCell({
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null); const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
const [channel, setChannel] = useState<string | undefined>(undefined); const [channel, setChannel] = useState<string | undefined>(undefined);
const [numPreviousTries, setNumPreviousTries] = useState(0);
const fetchMutex = useRef(false); const fetchMutex = useRef(false);
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => { const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
if (fetchMutex.current) return; async (forceRefetch?: boolean) => {
fetchMutex.current = true; if (fetchMutex.current) return;
setOutput(null); setNumPreviousTries((prev) => prev + 1);
const shouldStream = fetchMutex.current = true;
isObject(variant) && setOutput(null);
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channel = shouldStream ? generateChannel() : undefined; const shouldStream =
setChannel(channel); isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const output = await outputMutation.mutateAsync({ const channel = shouldStream ? generateChannel() : undefined;
scenarioId: scenario.id, setChannel(channel);
variantId: variant.id,
channel, const output = await outputMutation.mutateAsync({
}); scenarioId: scenario.id,
setOutput(output); variantId: variant.id,
await utils.promptVariants.stats.invalidate(); channel,
fetchMutex.current = false; forceRefetch,
}, [outputMutation, scenario.id, variant.id]); });
setOutput(output);
await utils.promptVariants.stats.invalidate();
fetchMutex.current = false;
},
[outputMutation, scenario.id, variant.id]
);
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
useEffect(fetchOutput, [scenario.id, variant.id]); useEffect(fetchOutput, [scenario.id, variant.id]);
@@ -93,7 +106,13 @@ export default function OutputCell({
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>; if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
if (output && output.errorMessage) { if (output && output.errorMessage) {
return <Text color="red.600">Error: {output.errorMessage}</Text>; return (
<ErrorHandler
output={output}
refetchOutput={hardRefetch}
numPreviousTries={numPreviousTries}
/>
);
} }
const response = output?.output as unknown as ChatCompletion; const response = output?.output as unknown as ChatCompletion;
@@ -142,54 +161,4 @@ export default function OutputCell({
); );
} }
const OutputStats = ({
model,
modelOutput,
scenario,
}: {
model: SupportedModel | null;
modelOutput: ModelOutput;
scenario: Scenario;
}) => {
const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens;
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
const completionCost =
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
const cost = promptCost + completionCost;
return (
<HStack align="center" color="gray.500" fontSize="xs" mt={2}>
<HStack flex={1}>
{evals.map((evaluation) => {
const passed = evaluateOutput(modelOutput, scenario, evaluation);
return (
<HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.name}</Text>
<Icon
as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"}
boxSize={6}
/>
</HStack>
);
})}
</HStack>
<HStack spacing={0}>
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{cost.toFixed(3)}</Text>
</HStack>
<HStack spacing={0.5}>
<Icon as={BsClock} />
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
</HStack>
</HStack>
);
};

View File

@@ -0,0 +1,61 @@
import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types";
import { useExperiment } from "~/utils/hooks";
import { api } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
/**
 * Footer row of statistics for a single model output: one pass/fail icon per
 * evaluation in the current experiment, the estimated dollar cost, and the
 * wall-clock completion time.
 */
export const OutputStats = ({
  model,
  modelOutput,
  scenario,
}: {
  model: SupportedModel | null;
  modelOutput: ModelOutput;
  scenario: Scenario;
}) => {
  const timeToComplete = modelOutput.timeToComplete;
  const experiment = useExperiment();
  const evals =
    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];

  // Token counts may be absent; treat a missing count (or unknown model) as zero cost.
  const { promptTokens, completionTokens } = modelOutput;
  const inputCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
  const outputCost =
    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
  const totalCost = inputCost + outputCost;

  return (
    <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
      <HStack flex={1}>
        {evals.map((evaluation) => {
          const passed = evaluateOutput(modelOutput, scenario, evaluation);
          return (
            <HStack spacing={0} key={evaluation.id}>
              <Text>{evaluation.name}</Text>
              <Icon
                as={passed ? BsCheck : BsX}
                color={passed ? "green.500" : "red.500"}
                boxSize={6}
              />
            </HStack>
          );
        })}
      </HStack>
      <HStack spacing={0}>
        <Icon as={BsCurrencyDollar} />
        <Text mr={1}>{totalCost.toFixed(3)}</Text>
      </HStack>
      <HStack spacing={0.5}>
        <Icon as={BsClock} />
        <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
      </HStack>
    </HStack>
  );
};

View File

@@ -1,7 +1,7 @@
import { Box, GridItem } from "@chakra-ui/react"; import { Box, GridItem } from "@chakra-ui/react";
import React, { useState } from "react"; import React, { useState } from "react";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import OutputCell from "./OutputCell"; import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor"; import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types"; import type { PromptVariant, Scenario } from "./types";

View File

@@ -11,7 +11,12 @@ import { getCompletion } from "~/server/utils/getCompletion";
export const modelOutputsRouter = createTRPCRouter({ export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure get: publicProcedure
.input( .input(
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }) z.object({
scenarioId: z.string(),
variantId: z.string(),
channel: z.string().optional(),
forceRefetch: z.boolean().optional(),
})
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({ const existing = await prisma.modelOutput.findUnique({
@@ -23,7 +28,7 @@ export const modelOutputsRouter = createTRPCRouter({
}, },
}); });
if (existing) return existing; if (existing && !input.forceRefetch) return existing;
const variant = await prisma.promptVariant.findUnique({ const variant = await prisma.promptVariant.findUnique({
where: { where: {
@@ -69,13 +74,22 @@ export const modelOutputsRouter = createTRPCRouter({
modelResponse = await getCompletion(filledTemplate, input.channel); modelResponse = await getCompletion(filledTemplate, input.channel);
} }
const modelOutput = await prisma.modelOutput.create({ const modelOutput = await prisma.modelOutput.upsert({
data: { where: {
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
}
},
create: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenarioId: input.scenarioId, testScenarioId: input.scenarioId,
inputHash, inputHash,
...modelResponse, ...modelResponse,
}, },
update: {
...modelResponse,
},
}); });
await reevaluateVariant(input.variantId); await reevaluateVariant(input.variantId);

View File

@@ -8,6 +8,7 @@ import { type JSONSerializable, OpenAIChatModel } from "../types";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens"; import { countOpenAIChatTokens } from "~/utils/countTokens";
import { getModelName } from "./getModelName"; import { getModelName } from "./getModelName";
import { rateLimitErrorMessage } from "~/sharedStrings";
env; env;
@@ -32,9 +33,7 @@ export async function getCompletion(
errorMessage: "Invalid payload provided", errorMessage: "Invalid payload provided",
timeToComplete: 0, timeToComplete: 0,
}; };
if ( if (modelName in OpenAIChatModel) {
modelName in OpenAIChatModel
) {
return getOpenAIChatCompletion( return getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams, payload as unknown as CompletionCreateParams,
env.OPENAI_API_KEY, env.OPENAI_API_KEY,
@@ -93,13 +92,15 @@ export async function getOpenAIChatCompletion(
} }
if (!response.ok) { if (!response.ok) {
// If it's an object, try to get the error message if (response.status === 429) {
if ( resp.errorMessage = rateLimitErrorMessage;
} else if (
isObject(resp.output) && isObject(resp.output) &&
"error" in resp.output && "error" in resp.output &&
isObject(resp.output.error) && isObject(resp.output.error) &&
"message" in resp.output.error "message" in resp.output.error
) { ) {
// If it's an object, try to get the error message
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error"; resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
} }
} }
@@ -108,16 +109,13 @@ export async function getOpenAIChatCompletion(
const usage = resp.output.usage as unknown as ChatCompletion.Usage; const usage = resp.output.usage as unknown as ChatCompletion.Usage;
resp.promptTokens = usage.prompt_tokens; resp.promptTokens = usage.prompt_tokens;
resp.completionTokens = usage.completion_tokens; resp.completionTokens = usage.completion_tokens;
} else if (isObject(resp.output) && 'choices' in resp.output) { } else if (isObject(resp.output) && "choices" in resp.output) {
const model = payload.model as unknown as OpenAIChatModel const model = payload.model as unknown as OpenAIChatModel;
resp.promptTokens = countOpenAIChatTokens( resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
model,
payload.messages
);
const choices = resp.output.choices as unknown as ChatCompletion.Choice[]; const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
const message = choices[0]?.message const message = choices[0]?.message;
if (message) { if (message) {
const messages = [message] const messages = [message];
resp.completionTokens = countOpenAIChatTokens(model, messages); resp.completionTokens = countOpenAIChatTokens(model, messages);
} }
} }

1
src/sharedStrings.ts Normal file
View File

@@ -0,0 +1 @@
// Shared sentinel string for OpenAI 429 responses: the server sets it as the
// output's errorMessage, and the client compares against it to decide whether
// an automatic retry is appropriate.
export const rateLimitErrorMessage = "429 - Rate limit exceeded.";