diff --git a/package.json b/package.json
index cd5402d..d4f6769 100644
--- a/package.json
+++ b/package.json
@@ -43,6 +43,7 @@
"next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1",
"openai": "4.0.0-beta.2",
+ "pluralize": "^8.0.0",
"posthog-js": "^1.68.4",
"react": "18.2.0",
"react-dom": "18.2.0",
@@ -64,6 +65,7 @@
"@types/express": "^4.17.17",
"@types/lodash": "^4.14.195",
"@types/node": "^18.16.0",
+ "@types/pluralize": "^0.0.30",
"@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 41da3d7..d9d47a6 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1,4 +1,4 @@
-lockfileVersion: '6.1'
+lockfileVersion: '6.0'
settings:
autoInstallPeers: true
@@ -92,6 +92,9 @@ dependencies:
openai:
specifier: 4.0.0-beta.2
version: 4.0.0-beta.2
+ pluralize:
+ specifier: ^8.0.0
+ version: 8.0.0
posthog-js:
specifier: ^1.68.4
version: 1.68.4
@@ -151,6 +154,9 @@ devDependencies:
'@types/node':
specifier: ^18.16.0
version: 18.16.0
+ '@types/pluralize':
+ specifier: ^0.0.30
+ version: 0.0.30
'@types/react':
specifier: ^18.2.6
version: 18.2.6
@@ -2179,6 +2185,10 @@ packages:
resolution: {integrity: sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==}
dev: false
+ /@types/pluralize@0.0.30:
+ resolution: {integrity: sha512-kVww6xZrW/db5BR9OqiT71J9huRdQ+z/r+LbDuT7/EK50mCmj5FoaIARnVv0rvjUS/YpDox0cDU9lpQT011VBA==}
+ dev: true
+
/@types/prop-types@15.7.5:
resolution: {integrity: sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==}
@@ -4883,6 +4893,11 @@ packages:
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
engines: {node: '>=8.6'}
+ /pluralize@8.0.0:
+ resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
+ engines: {node: '>=4'}
+ dev: false
+
/postcss@8.4.14:
resolution: {integrity: sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig==}
engines: {node: ^10 || ^12 || >=14}
diff --git a/src/components/OutputsTable/OutputCell/ErrorHandler.tsx b/src/components/OutputsTable/OutputCell/ErrorHandler.tsx
new file mode 100644
index 0000000..0e92a1a
--- /dev/null
+++ b/src/components/OutputsTable/OutputCell/ErrorHandler.tsx
@@ -0,0 +1,102 @@
+import { type ModelOutput } from "@prisma/client";
+import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
+import { useEffect, useMemo, useState } from "react";
+import { BsArrowClockwise } from "react-icons/bs";
+import { rateLimitErrorMessage } from "~/sharedStrings";
+import pluralize from "pluralize";
+
+const MAX_AUTO_RETRIES = 3;
+
+export const ErrorHandler = ({
+ output,
+ refetchOutput,
+ numPreviousTries,
+}: {
+ output: ModelOutput;
+ refetchOutput: () => void;
+ numPreviousTries: number;
+}) => {
+ const [msToWait, setMsToWait] = useState(0);
+ const shouldAutoRetry =
+ output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
+
+ const errorMessage = useMemo(() => breakLongWords(output.errorMessage), [output.errorMessage]);
+
+ useEffect(() => {
+ if (!shouldAutoRetry) return;
+
+ const initialWaitTime = calculateDelay(numPreviousTries);
+ const msModuloOneSecond = initialWaitTime % 1000;
+ let remainingTime = initialWaitTime - msModuloOneSecond;
+ setMsToWait(remainingTime);
+
+ let interval: NodeJS.Timeout;
+ const timeout = setTimeout(() => {
+ interval = setInterval(() => {
+ remainingTime -= 1000;
+ setMsToWait(remainingTime);
+
+ if (remainingTime <= 0) {
+ refetchOutput();
+ clearInterval(interval);
+ }
+ }, 1000);
+ }, msModuloOneSecond);
+
+ return () => {
+ clearInterval(interval);
+ clearTimeout(timeout);
+ };
+ }, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
+
+ return (
+
+
+
+ Error
+
+
+
+ {errorMessage}
+ {msToWait > 0 && (
+
+ Retrying in {pluralize("second", Math.ceil(msToWait / 1000), true)}...
+
+ )}
+
+ );
+};
+
+function breakLongWords(str: string | null): string {
+ if (!str) return "";
+ const words = str.split(" ");
+
+ const newWords = words.map((word) => {
+ return word.length > 20 ? word.slice(0, 20) + "\u200B" + word.slice(20) : word;
+ });
+
+ return newWords.join(" ");
+}
+
+const MIN_DELAY = 500; // milliseconds
+const MAX_DELAY = 5000; // milliseconds
+
+function calculateDelay(numPreviousTries: number): number {
+ const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
+ const jitter = Math.random() * baseDelay;
+ return baseDelay + jitter;
+}
diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
similarity index 57%
rename from src/components/OutputsTable/OutputCell.tsx
rename to src/components/OutputsTable/OutputCell/OutputCell.tsx
index de6dcdb..012a6e9 100644
--- a/src/components/OutputsTable/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -1,21 +1,25 @@
import { type RouterOutputs, api } from "~/utils/api";
-import { type PromptVariant, type Scenario } from "./types";
-import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
+import { type PromptVariant, type Scenario } from "../types";
+import {
+ Spinner,
+ Text,
+ Box,
+ Center,
+ Flex,
+} from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
-import { type ReactElement, useState, useEffect, useRef } from "react";
-import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs";
-import { type ModelOutput } from "@prisma/client";
+import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannel } from "~/utils/generateChannel";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
-import { evaluateOutput } from "~/server/utils/evaluateOutput";
-import { calculateTokenCost } from "~/utils/calculateTokenCost";
-import { type JSONSerializable, type SupportedModel } from "~/server/types";
+import { type JSONSerializable } from "~/server/types";
import { getModelName } from "~/server/utils/getModelName";
+import { OutputStats } from "./OutputStats";
+import { ErrorHandler } from "./ErrorHandler";
export default function OutputCell({
scenario,
@@ -47,31 +51,40 @@ export default function OutputCell({
const [output, setOutput] = useState(null);
const [channel, setChannel] = useState(undefined);
+ const [numPreviousTries, setNumPreviousTries] = useState(0);
+
const fetchMutex = useRef(false);
- const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => {
- if (fetchMutex.current) return;
- fetchMutex.current = true;
- setOutput(null);
+ const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
+ async (forceRefetch?: boolean) => {
+ if (fetchMutex.current) return;
+ setNumPreviousTries((prev) => prev + 1);
- const shouldStream =
- isObject(variant) &&
- "config" in variant &&
- isObject(variant.config) &&
- "stream" in variant.config &&
- variant.config.stream === true;
+ fetchMutex.current = true;
+ setOutput(null);
- const channel = shouldStream ? generateChannel() : undefined;
- setChannel(channel);
+ const shouldStream =
+ isObject(variant) &&
+ "config" in variant &&
+ isObject(variant.config) &&
+ "stream" in variant.config &&
+ variant.config.stream === true;
- const output = await outputMutation.mutateAsync({
- scenarioId: scenario.id,
- variantId: variant.id,
- channel,
- });
- setOutput(output);
- await utils.promptVariants.stats.invalidate();
- fetchMutex.current = false;
- }, [outputMutation, scenario.id, variant.id]);
+ const channel = shouldStream ? generateChannel() : undefined;
+ setChannel(channel);
+
+ const output = await outputMutation.mutateAsync({
+ scenarioId: scenario.id,
+ variantId: variant.id,
+ channel,
+ forceRefetch,
+ });
+ setOutput(output);
+ await utils.promptVariants.stats.invalidate();
+ fetchMutex.current = false;
+ },
+ [outputMutation, scenario.id, variant.id]
+ );
+ const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
useEffect(fetchOutput, [scenario.id, variant.id]);
@@ -93,7 +106,13 @@ export default function OutputCell({
if (!output && !fetchingOutput) return Error retrieving output;
if (output && output.errorMessage) {
- return Error: {output.errorMessage};
+ return (
+
+ );
}
const response = output?.output as unknown as ChatCompletion;
@@ -142,54 +161,4 @@ export default function OutputCell({
);
}
-const OutputStats = ({
- model,
- modelOutput,
- scenario,
-}: {
- model: SupportedModel | null;
- modelOutput: ModelOutput;
- scenario: Scenario;
-}) => {
- const timeToComplete = modelOutput.timeToComplete;
- const experiment = useExperiment();
- const evals =
- api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
- const promptTokens = modelOutput.promptTokens;
- const completionTokens = modelOutput.completionTokens;
-
- const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
- const completionCost =
- completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
-
- const cost = promptCost + completionCost;
-
- return (
-
-
- {evals.map((evaluation) => {
- const passed = evaluateOutput(modelOutput, scenario, evaluation);
- return (
-
- {evaluation.name}
-
-
- );
- })}
-
-
-
- {cost.toFixed(3)}
-
-
-
- {(timeToComplete / 1000).toFixed(2)}s
-
-
- );
-};
diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx
new file mode 100644
index 0000000..59bc17e
--- /dev/null
+++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx
@@ -0,0 +1,61 @@
+import { type ModelOutput } from "@prisma/client";
+import { type SupportedModel } from "~/server/types";
+import { type Scenario } from "../types";
+import { useExperiment } from "~/utils/hooks";
+import { api } from "~/utils/api";
+import { calculateTokenCost } from "~/utils/calculateTokenCost";
+import { evaluateOutput } from "~/server/utils/evaluateOutput";
+import { HStack, Icon, Text } from "@chakra-ui/react";
+import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
+
+export const OutputStats = ({
+ model,
+ modelOutput,
+ scenario,
+}: {
+ model: SupportedModel | null;
+ modelOutput: ModelOutput;
+ scenario: Scenario;
+}) => {
+ const timeToComplete = modelOutput.timeToComplete;
+ const experiment = useExperiment();
+ const evals =
+ api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
+
+ const promptTokens = modelOutput.promptTokens;
+ const completionTokens = modelOutput.completionTokens;
+
+ const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
+ const completionCost =
+ completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
+
+ const cost = promptCost + completionCost;
+
+ return (
+
+
+ {evals.map((evaluation) => {
+ const passed = evaluateOutput(modelOutput, scenario, evaluation);
+ return (
+
+ {evaluation.name}
+
+
+ );
+ })}
+
+
+
+ {cost.toFixed(3)}
+
+
+
+ {(timeToComplete / 1000).toFixed(2)}s
+
+
+ );
+};
diff --git a/src/components/OutputsTable/ScenarioRow.tsx b/src/components/OutputsTable/ScenarioRow.tsx
index d12e1dc..da09867 100644
--- a/src/components/OutputsTable/ScenarioRow.tsx
+++ b/src/components/OutputsTable/ScenarioRow.tsx
@@ -1,7 +1,7 @@
import { Box, GridItem } from "@chakra-ui/react";
import React, { useState } from "react";
import { cellPadding } from "../constants";
-import OutputCell from "./OutputCell";
+import OutputCell from "./OutputCell/OutputCell";
import ScenarioEditor from "./ScenarioEditor";
import type { PromptVariant, Scenario } from "./types";
diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts
index 12d0192..5f5d1fc 100644
--- a/src/server/api/routers/modelOutputs.router.ts
+++ b/src/server/api/routers/modelOutputs.router.ts
@@ -11,7 +11,12 @@ import { getCompletion } from "~/server/utils/getCompletion";
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(
- z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
+ z.object({
+ scenarioId: z.string(),
+ variantId: z.string(),
+ channel: z.string().optional(),
+ forceRefetch: z.boolean().optional(),
+ })
)
.mutation(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
@@ -23,7 +28,7 @@ export const modelOutputsRouter = createTRPCRouter({
},
});
- if (existing) return existing;
+ if (existing && !input.forceRefetch) return existing;
const variant = await prisma.promptVariant.findUnique({
where: {
@@ -69,13 +74,22 @@ export const modelOutputsRouter = createTRPCRouter({
modelResponse = await getCompletion(filledTemplate, input.channel);
}
- const modelOutput = await prisma.modelOutput.create({
- data: {
+ const modelOutput = await prisma.modelOutput.upsert({
+ where: {
+ promptVariantId_testScenarioId: {
+ promptVariantId: input.variantId,
+ testScenarioId: input.scenarioId,
+ }
+ },
+ create: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
inputHash,
...modelResponse,
},
+ update: {
+ ...modelResponse,
+ },
});
await reevaluateVariant(input.variantId);
diff --git a/src/server/utils/getCompletion.ts b/src/server/utils/getCompletion.ts
index 80cfa8c..2419eed 100644
--- a/src/server/utils/getCompletion.ts
+++ b/src/server/utils/getCompletion.ts
@@ -8,6 +8,7 @@ import { type JSONSerializable, OpenAIChatModel } from "../types";
import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { getModelName } from "./getModelName";
+import { rateLimitErrorMessage } from "~/sharedStrings";
env;
@@ -32,9 +33,7 @@ export async function getCompletion(
errorMessage: "Invalid payload provided",
timeToComplete: 0,
};
- if (
- modelName in OpenAIChatModel
- ) {
+ if (modelName in OpenAIChatModel) {
return getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
env.OPENAI_API_KEY,
@@ -93,13 +92,15 @@ export async function getOpenAIChatCompletion(
}
if (!response.ok) {
- // If it's an object, try to get the error message
- if (
+ if (response.status === 429) {
+ resp.errorMessage = rateLimitErrorMessage;
+ } else if (
isObject(resp.output) &&
"error" in resp.output &&
isObject(resp.output.error) &&
"message" in resp.output.error
) {
+ // If it's an object, try to get the error message
resp.errorMessage = resp.output.error.message?.toString() ?? "Unknown error";
}
}
@@ -108,16 +109,13 @@ export async function getOpenAIChatCompletion(
const usage = resp.output.usage as unknown as ChatCompletion.Usage;
resp.promptTokens = usage.prompt_tokens;
resp.completionTokens = usage.completion_tokens;
- } else if (isObject(resp.output) && 'choices' in resp.output) {
- const model = payload.model as unknown as OpenAIChatModel
- resp.promptTokens = countOpenAIChatTokens(
- model,
- payload.messages
- );
+ } else if (isObject(resp.output) && "choices" in resp.output) {
+ const model = payload.model as unknown as OpenAIChatModel;
+ resp.promptTokens = countOpenAIChatTokens(model, payload.messages);
const choices = resp.output.choices as unknown as ChatCompletion.Choice[];
- const message = choices[0]?.message
+ const message = choices[0]?.message;
if (message) {
- const messages = [message]
+ const messages = [message];
resp.completionTokens = countOpenAIChatTokens(model, messages);
}
}
diff --git a/src/sharedStrings.ts b/src/sharedStrings.ts
new file mode 100644
index 0000000..65214de
--- /dev/null
+++ b/src/sharedStrings.ts
@@ -0,0 +1 @@
+export const rateLimitErrorMessage = "429 - Rate limit exceeded.";