Retry requests that receive 429 (#15)
* List number of scenarios
* Retry requests after 429
* Rename requestCallback
* Add sleep function
* Allow manual retry on frontend
* Remove unused utility functions
* Auto refetch
* Display wait time with Math.ceil
* Take one second modulo into account
* Add pluralize
package.json
@@ -43,6 +43,7 @@
   "next-auth": "^4.22.1",
   "nextjs-routes": "^2.0.1",
   "openai": "4.0.0-beta.2",
+  "pluralize": "^8.0.0",
   "posthog-js": "^1.68.4",
   "react": "18.2.0",
   "react-dom": "18.2.0",
@@ -64,6 +65,7 @@
   "@types/express": "^4.17.17",
   "@types/lodash": "^4.14.195",
   "@types/node": "^18.16.0",
+  "@types/pluralize": "^0.0.30",
   "@types/react": "^18.2.6",
   "@types/react-dom": "^18.2.4",
   "@types/react-syntax-highlighter": "^15.5.7",

pnpm-lock.yaml (generated, 17 changed lines)
@@ -1,4 +1,4 @@
-lockfileVersion: '6.1'
+lockfileVersion: '6.0'
 
 settings:
   autoInstallPeers: true
@@ -92,6 +92,9 @@ dependencies:
   openai:
     specifier: 4.0.0-beta.2
     version: 4.0.0-beta.2
+  pluralize:
+    specifier: ^8.0.0
+    version: 8.0.0
   posthog-js:
     specifier: ^1.68.4
     version: 1.68.4
@@ -151,6 +154,9 @@ devDependencies:
   '@types/node':
     specifier: ^18.16.0
     version: 18.16.0
+  '@types/pluralize':
+    specifier: ^0.0.30
+    version: 0.0.30
   '@types/react':
     specifier: ^18.2.6
     version: 18.2.6
@@ -2179,6 +2185,10 @@ packages:
     resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==}
     dev: false
 
+  /@types/pluralize@0.0.30:
+    resolution: {integrity: sha512-kVww6xZrW/db5BR9OqiT71J9huRdQ+z/r+LbDuT7/EK50mCmj5FoaIARnVv0rvjUS/YpDox0cDU9lpQT011VBA==}
+    dev: true
+
   /@types/prop-types@15.7.5:
     resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==}
 
@@ -4883,6 +4893,11 @@ packages:
     resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
     engines: {node: '>=8.6'}
 
+  /pluralize@8.0.0:
+    resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
+    engines: {node: '>=4'}
+    dev: false
+
   /postcss@8.4.14:
     resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==}
     engines: {node: ^10 || ^12 || >=14}

src/components/OutputsTable/OutputCell/ErrorHandler.tsx (new file, 102 lines)
@@ -0,0 +1,102 @@
+import { type ModelOutput } from "@prisma/client";
+import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
+import { useEffect, useMemo, useState } from "react";
+import { BsArrowClockwise } from "react-icons/bs";
+import { rateLimitErrorMessage } from "~/sharedStrings";
+import pluralize from 'pluralize'
+
+const MAX_AUTO_RETRIES = 3;
+
+export const ErrorHandler = ({
+  output,
+  refetchOutput,
+  numPreviousTries,
+}: {
+  output: ModelOutput;
+  refetchOutput: () => void;
+  numPreviousTries: number;
+}) => {
+  const [msToWait, setMsToWait] = useState(0);
+  const shouldAutoRetry =
+    output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
+
+  const errorMessage = useMemo(() => breakLongWords(output.errorMessage), [output.errorMessage]);
+
+  useEffect(() => {
+    if (!shouldAutoRetry) return;
+
+    const initialWaitTime = calculateDelay(numPreviousTries);
+    const msModuloOneSecond = initialWaitTime % 1000;
+    let remainingTime = initialWaitTime - msModuloOneSecond;
+    setMsToWait(remainingTime);
+
+    let interval: NodeJS.Timeout;
+    const timeout = setTimeout(() => {
+      interval = setInterval(() => {
+        remainingTime -= 1000;
+        setMsToWait(remainingTime);
+
+        if (remainingTime <= 0) {
+          refetchOutput();
+          clearInterval(interval);
+        }
+      }, 1000);
+    }, msModuloOneSecond);
+
+    return () => {
+      clearInterval(interval);
+      clearTimeout(timeout);
+    };
+  }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
+
+  return (
+    <VStack w="full">
+      <HStack w="full" alignItems="flex-start" justifyContent="space-between">
+        <Text color="red.600" fontWeight="bold">
+          Error
+        </Text>
+        <Button
+          size="xs"
+          w={4}
+          h={4}
+          px={4}
+          py={4}
+          minW={0}
+          borderRadius={8}
+          variant="ghost"
+          cursor="pointer"
+          onClick={refetchOutput}
+          aria-label="refetch output"
+        >
+          <Icon as={BsArrowClockwise} boxSize={6} />
+        </Button>
+      </HStack>
+      <Text color="red.600">{errorMessage}</Text>
+      {msToWait > 0 && (
+        <Text color="red.600" fontSize="sm">
+          Retrying in {pluralize('second', Math.ceil(msToWait / 1000), true)}...
+        </Text>
+      )}
+    </VStack>
+  );
+};
+
+function breakLongWords(str: string | null): string {
+  if (!str) return "";
+  const words = str.split(" ");
+
+  const newWords = words.map((word) => {
+    return word.length > 20 ? word.slice(0, 20) + "\u200B" + word.slice(20) : word;
+  });
+
+  return newWords.join(" ");
+}
+
+const MIN_DELAY = 500; // milliseconds
+const MAX_DELAY = 5000; // milliseconds
+
+function calculateDelay(numPreviousTries: number): number {
+  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
+  const jitter = Math.random() * baseDelay;
+  return baseDelay + jitter;
+}
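
A quick worked example of the backoff math above, since two of the commit bullets ("Display wait time with Math.ceil", "Take one second modulo into account") describe it. The numbers below are illustrative, not from the PR:

    // For the second automatic retry (numPreviousTries = 1), using the same constants as ErrorHandler.tsx.
    const MIN_DELAY = 500;  // ms
    const MAX_DELAY = 5000; // ms
    const numPreviousTries = 1;

    const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries)); // 1000 ms
    const jitter = Math.random() * baseDelay;   // somewhere in [0, 1000) ms
    const initialWaitTime = baseDelay + jitter; // e.g. 1350 ms

    // The component first sleeps off the fractional second (350 ms here via setTimeout),
    // then decrements once per second, so the visible countdown moves in whole seconds.
    const msModuloOneSecond = initialWaitTime % 1000;               // 350 ms
    const remainingTime = initialWaitTime - msModuloOneSecond;      // 1000 ms
    console.log(Math.ceil(remainingTime / 1000));                   // 1 -> rendered as "1 second"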

src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -1,21 +1,25 @@
 import { type RouterOutputs, api } from "~/utils/api";
-import { type PromptVariant, type Scenario } from "./types";
-import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
+import { type PromptVariant, type Scenario } from "../types";
+import {
+  Spinner,
+  Text,
+  Box,
+  Center,
+  Flex,
+} from "@chakra-ui/react";
 import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
 import SyntaxHighlighter from "react-syntax-highlighter";
 import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
 import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect, useRef } from "react";
-import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
-import { type ModelOutput } from "@prisma/client";
+import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
 import { type ChatCompletion } from "openai/resources/chat";
 import { generateChannel } from "~/utils/generateChannel";
 import { isObject } from "lodash";
 import useSocket from "~/utils/useSocket";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { type JSONSerializable, type SupportedModel } from "~/server/types";
+import { type JSONSerializable } from "~/server/types";
 import { getModelName } from "~/server/utils/getModelName";
+import { OutputStats } from "./OutputStats";
+import { ErrorHandler } from "./ErrorHandler";
 
 export default function OutputCell({
   scenario,
@@ -47,31 +51,40 @@ export default function OutputCell({
 
   const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
   const [channel, setChannel] = useState<string | undefined>(undefined);
+  const [numPreviousTries, setNumPreviousTries] = useState(0);
+
   const fetchMutex = useRef(false);
-  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => {
-    if (fetchMutex.current) return;
-    fetchMutex.current = true;
-    setOutput(null);
-
-    const shouldStream =
-      isObject(variant) &&
-      "config" in variant &&
-      isObject(variant.config) &&
-      "stream" in variant.config &&
-      variant.config.stream === true;
-
-    const channel = shouldStream ? generateChannel() : undefined;
-    setChannel(channel);
-
-    const output = await outputMutation.mutateAsync({
-      scenarioId: scenario.id,
-      variantId: variant.id,
-      channel,
-    });
-    setOutput(output);
-    await utils.promptVariants.stats.invalidate();
-    fetchMutex.current = false;
-  }, [outputMutation, scenario.id, variant.id]);
+  const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
+    async (forceRefetch?: boolean) => {
+      if (fetchMutex.current) return;
+      setNumPreviousTries((prev) => prev + 1);
+
+      fetchMutex.current = true;
+      setOutput(null);
+
+      const shouldStream =
+        isObject(variant) &&
+        "config" in variant &&
+        isObject(variant.config) &&
+        "stream" in variant.config &&
+        variant.config.stream === true;
+
+      const channel = shouldStream ? generateChannel() : undefined;
+      setChannel(channel);
+
+      const output = await outputMutation.mutateAsync({
+        scenarioId: scenario.id,
+        variantId: variant.id,
+        channel,
+        forceRefetch,
+      });
+      setOutput(output);
+      await utils.promptVariants.stats.invalidate();
+      fetchMutex.current = false;
+    },
+    [outputMutation, scenario.id, variant.id]
+  );
+  const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
 
   useEffect(fetchOutput, [scenario.id, variant.id]);
 
@@ -93,7 +106,13 @@ export default function OutputCell({
   if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
 
   if (output && output.errorMessage) {
-    return <Text color="red.600">Error: {output.errorMessage}</Text>;
+    return (
+      <ErrorHandler
+        output={output}
+        refetchOutput={hardRefetch}
+        numPreviousTries={numPreviousTries}
+      />
+    );
   }
 
   const response = output?.output as unknown as ChatCompletion;
@@ -142,54 +161,4 @@ export default function OutputCell({
   );
 }
 
-const OutputStats = ({
-  model,
-  modelOutput,
-  scenario,
-}: {
-  model: SupportedModel | null;
-  modelOutput: ModelOutput;
-  scenario: Scenario;
-}) => {
-  const timeToComplete = modelOutput.timeToComplete;
-  const experiment = useExperiment();
-  const evals =
-    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
-
-  const promptTokens = modelOutput.promptTokens;
-  const completionTokens = modelOutput.completionTokens;
-
-  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
-  const completionCost =
-    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-
-  const cost = promptCost + completionCost;
-
-  return (
-    <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
-      <HStack flex={1}>
-        {evals.map((evaluation) => {
-          const passed = evaluateOutput(modelOutput, scenario, evaluation);
-          return (
-            <HStack spacing={0} key={evaluation.id}>
-              <Text>{evaluation.name}</Text>
-              <Icon
-                as={passed ? BsCheck : BsX}
-                color={passed ? "green.500" : "red.500"}
-                boxSize={6}
-              />
-            </HStack>
-          );
-        })}
-      </HStack>
-      <HStack spacing={0}>
-        <Icon as={BsCurrencyDollar} />
-        <Text mr={1}>{cost.toFixed(3)}</Text>
-      </HStack>
-      <HStack spacing={0.5}>
-        <Icon as={BsClock} />
-        <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
-      </HStack>
-    </HStack>
-  );
-};
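
One cross-hunk note: the retry button and the auto-retry countdown in ErrorHandler both call the refetchOutput prop, which is hardRefetch here, i.e. fetchOutput(true). A minimal sketch of the call that produces, with the tRPC mutation stubbed out (the stub is not part of the PR):

    // Stub standing in for the tRPC outputs mutation used by OutputCell.
    type GetOutputInput = {
      scenarioId: string;
      variantId: string;
      channel?: string;
      forceRefetch?: boolean;
    };
    const outputMutation = {
      mutateAsync: async (input: GetOutputInput) => ({ requested: input }),
    };

    const fetchOutput = (forceRefetch?: boolean) =>
      outputMutation.mutateAsync({ scenarioId: "scenario-id", variantId: "variant-id", forceRefetch });

    // Mirrors the useCallback above: a retry from ErrorHandler always forces a fresh completion.
    const hardRefetch = () => fetchOutput(true);

    void hardRefetch().then((result) => console.log(result.requested.forceRefetch)); // true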

src/components/OutputsTable/OutputCell/OutputStats.tsx (new file, 61 lines)
@@ -0,0 +1,61 @@
+import { type ModelOutput } from "@prisma/client";
+import { type SupportedModel } from "~/server/types";
+import { type Scenario } from "../types";
+import { useExperiment } from "~/utils/hooks";
+import { api } from "~/utils/api";
+import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { evaluateOutput } from "~/server/utils/evaluateOutput";
+import { HStack, Icon, Text } from "@chakra-ui/react";
+import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
+
+export const OutputStats = ({
+  model,
+  modelOutput,
+  scenario,
+}: {
+  model: SupportedModel | null;
+  modelOutput: ModelOutput;
+  scenario: Scenario;
+}) => {
+  const timeToComplete = modelOutput.timeToComplete;
+  const experiment = useExperiment();
+  const evals =
+    api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
+
+  const promptTokens = modelOutput.promptTokens;
+  const completionTokens = modelOutput.completionTokens;
+
+  const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
+  const completionCost =
+    completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
+
+  const cost = promptCost + completionCost;
+
+  return (
+    <HStack align="center" color="gray.500" fontSize="xs" mt={2}>
+      <HStack flex={1}>
+        {evals.map((evaluation) => {
+          const passed = evaluateOutput(modelOutput, scenario, evaluation);
+          return (
+            <HStack spacing={0} key={evaluation.id}>
+              <Text>{evaluation.name}</Text>
+              <Icon
+                as={passed ? BsCheck : BsX}
+                color={passed ? "green.500" : "red.500"}
+                boxSize={6}
+              />
+            </HStack>
+          );
+        })}
+      </HStack>
+      <HStack spacing={0}>
+        <Icon as={BsCurrencyDollar} />
+        <Text mr={1}>{cost.toFixed(3)}</Text>
+      </HStack>
+      <HStack spacing={0.5}>
+        <Icon as={BsClock} />
+        <Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
+      </HStack>
+    </HStack>
+  );
+};
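
For readers skimming the cost line: a rough, illustrative trace of how the displayed dollar figure is assembled. calculateTokenCost itself is not part of this diff; the per-token prices below are made-up placeholders, not real pricing:

    // Placeholder stand-in for ~/utils/calculateTokenCost (real implementation not shown in this PR).
    const calculateTokenCost = (model: string, tokens: number, isCompletion = false) =>
      tokens * (isCompletion ? 0.000002 : 0.000001);

    const promptTokens = 1200;
    const completionTokens = 350;
    const promptCost = calculateTokenCost("gpt-3.5-turbo", promptTokens);              // 0.0012
    const completionCost = calculateTokenCost("gpt-3.5-turbo", completionTokens, true); // 0.0007
    console.log((promptCost + completionCost).toFixed(3)); // "0.002", shown next to the dollar icon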

@@ -1,7 +1,7 @@
 import { Box, GridItem } from "@chakra-ui/react";
 import React, { useState } from "react";
 import { cellPadding } from "../constants";
-import OutputCell from "./OutputCell";
+import OutputCell from "./OutputCell/OutputCell";
 import ScenarioEditor from "./ScenarioEditor";
 import type { PromptVariant, Scenario } from "./types";
 

@@ -11,7 +11,12 @@ import { getCompletion } from "~/server/utils/getCompletion";
 export const modelOutputsRouter = createTRPCRouter({
   get: publicProcedure
     .input(
-      z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
+      z.object({
+        scenarioId: z.string(),
+        variantId: z.string(),
+        channel: z.string().optional(),
+        forceRefetch: z.boolean().optional(),
+      })
     )
     .mutation(async ({ input }) => {
       const existing = await prisma.modelOutput.findUnique({
@@ -23,7 +28,7 @@ export const modelOutputsRouter = createTRPCRouter({
         },
       });
 
-      if (existing) return existing;
+      if (existing && !input.forceRefetch) return existing;
 
       const variant = await prisma.promptVariant.findUnique({
         where: {
@@ -69,13 +74,22 @@ export const modelOutputsRouter = createTRPCRouter({
         modelResponse = await getCompletion(filledTemplate, input.channel);
       }
 
-      const modelOutput = await prisma.modelOutput.create({
-        data: {
+      const modelOutput = await prisma.modelOutput.upsert({
+        where: {
+          promptVariantId_testScenarioId: {
+            promptVariantId: input.variantId,
+            testScenarioId: input.scenarioId,
+          }
+        },
+        create: {
           promptVariantId: input.variantId,
           testScenarioId: input.scenarioId,
           inputHash,
           ...modelResponse,
         },
+        update: {
+          ...modelResponse,
+        },
       });
 
       await reevaluateVariant(input.variantId);
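
In plain terms, the router now treats an existing ModelOutput as a cache that forceRefetch can bypass, and the write becomes an upsert keyed by the variant/scenario pair (which assumes a compound unique constraint on those two columns in the Prisma schema; the schema is not part of this diff). A simplified, Prisma-free sketch of the flow:

    // Simplified model of the get-or-regenerate flow; a Map stands in for the ModelOutput table.
    type Key = { variantId: string; scenarioId: string };
    const outputs = new Map<string, string>();

    async function getOutput(key: Key, generate: () => Promise<string>, forceRefetch = false) {
      const cacheKey = `${key.variantId}:${key.scenarioId}`;
      const existing = outputs.get(cacheKey);
      if (existing && !forceRefetch) return existing; // `if (existing && !input.forceRefetch) return existing;`
      const fresh = await generate();                 // getCompletion(filledTemplate, channel)
      outputs.set(cacheKey, fresh);                   // prisma.modelOutput.upsert(...)
      return fresh;
    }

    // First call generates, the second reuses the cached row, the third regenerates.
    void (async () => {
      const generate = async () => `output @ ${Date.now()}`;
      console.log(await getOutput({ variantId: "v1", scenarioId: "s1" }, generate));
      console.log(await getOutput({ variantId: "v1", scenarioId: "s1" }, generate));
      console.log(await getOutput({ variantId: "v1", scenarioId: "s1" }, generate, true));
    })();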

src/server/utils/getCompletion.ts
@@ -8,6 +8,7 @@ import { type JSONSerializable, OpenAIChatModel } from "../types";
 import { env } from "~/env.mjs";
 import { countOpenAIChatTokens } from "~/utils/countTokens";
 import { getModelName } from "./getModelName";
+import { rateLimitErrorMessage } from "~/sharedStrings";
 
 env;
 
@@ -32,9 +33,7 @@ export async function getCompletion(
       errorMessage: "Invalid payload provided",
       timeToComplete: 0,
     };
-  if (
-    modelName in OpenAIChatModel
-  ) {
+  if (modelName in OpenAIChatModel) {
     return getOpenAIChatCompletion(
       payload as unknown as CompletionCreateParams,
       env.OPENAI_API_KEY,
@@ -93,13 +92,15 @@ export async function getOpenAIChatCompletion(
   }
 
   if (!response.ok) {
-    // If it's an object, try to get the error message
-    if (
+    if (response.status === 429) {
+      resp.errorMessage = rateLimitErrorMessage;
+    } else if (
       isObject(resp.output) &&
       "error" in resp.output &&
       isObject(resp.output.error) &&
       "message" in resp.output.error
     ) {
+      // If it's an object, try to get the error message
      resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
     }
   }
@@ -108,16 +109,13 @@ export async function getOpenAIChatCompletion(
     const usage = resp.output.usage as unknown as ChatCompletion.Usage;
     resp.promptTokens = usage.prompt_tokens;
     resp.completionTokens = usage.completion_tokens;
-  } else if (isObject(resp.output) && 'choices' in resp.output) {
-    const model = payload.model as unknown as OpenAIChatModel
-    resp.promptTokens = countOpenAIChatTokens(
-      model,
-      payload.messages
-    );
+  } else if (isObject(resp.output) && "choices" in resp.output) {
+    const model = payload.model as unknown as OpenAIChatModel;
+    resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
     const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
-    const message = choices[0]?.message
+    const message = choices[0]?.message;
     if (message) {
-      const messages = [message]
+      const messages = [message];
       resp.completionTokens = countOpenAIChatTokens(model, messages);
     }
   }

src/sharedStrings.ts (new file, 1 line)
@@ -0,0 +1 @@
+export const rateLimitErrorMessage = "429 - Rate limit exceeded.";
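
The retry feature hinges on this string being compared verbatim on both sides: getCompletion.ts writes it into errorMessage on a 429, and ErrorHandler.tsx only auto-retries on an exact match. A self-contained sketch of that contract (fetch and response handling omitted):

    const rateLimitErrorMessage = "429 - Rate limit exceeded.";
    const MAX_AUTO_RETRIES = 3;

    // Server side (getCompletion.ts): tag 429 responses with the shared string.
    const toErrorMessage = (status: number, fallback: string) =>
      status === 429 ? rateLimitErrorMessage : fallback;

    // Client side (ErrorHandler.tsx): auto-retry only on an exact match, up to the cap.
    const shouldAutoRetry = (errorMessage: string | null, numPreviousTries: number) =>
      errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;

    console.log(shouldAutoRetry(toErrorMessage(429, "Unknown error"), 1)); // true
    console.log(shouldAutoRetry(toErrorMessage(500, "Unknown error"), 1)); // false
    console.log(shouldAutoRetry(toErrorMessage(429, "Unknown error"), 3)); // false (manual retry only)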