Trigger llm output retrieval on server (#39)
* Rename tables, add graphile workers, update types * Add dev:worker command * Update pnpm-lock.yaml * Remove sentry config import from worker.ts * Stop generating new cells in cell router get query * Generate new cells for new scenarios, variants, and experiments * Remove most error throwing from queryLLM.task.ts * Remove promptVariantId and testScenarioId from ModelOutput * Remove duplicate index from ModelOutput * Move inputHash from cell to output * Add TODO * Add todo * Show cost and time for each cell * Always show output stats if there is output * Trigger LLM outputs when scenario variables are updated * Add newlines to ends of files * Add another newline * Cascade ModelOutput deletion * Fix linting and prettier * Return instead of throwing for non-pending cell * Remove pnpm dev:worker from pnpm:dev * Update pnpm-lock.yaml
This commit is contained in:
33
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
33
src/components/OutputsTable/OutputCell/CellOptions.tsx
Normal file
@@ -0,0 +1,33 @@
|
||||
import { Button, HStack, Icon } from "@chakra-ui/react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
|
||||
export const CellOptions = ({
|
||||
refetchingOutput,
|
||||
refetchOutput,
|
||||
}: {
|
||||
refetchingOutput: boolean;
|
||||
refetchOutput: () => void;
|
||||
}) => {
|
||||
return (
|
||||
<HStack justifyContent="flex-end" w="full">
|
||||
{!refetchingOutput && (
|
||||
<Button
|
||||
size="xs"
|
||||
w={4}
|
||||
h={4}
|
||||
py={4}
|
||||
px={4}
|
||||
minW={0}
|
||||
borderRadius={8}
|
||||
color="gray.500"
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={4} />
|
||||
</Button>
|
||||
)}
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
@@ -1,29 +1,21 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { HStack, VStack, Text, Button, Icon } from "@chakra-ui/react";
|
||||
import { type ScenarioVariantCell } from "@prisma/client";
|
||||
import { VStack, Text } from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { BsArrowClockwise } from "react-icons/bs";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
import pluralize from "pluralize";
|
||||
|
||||
const MAX_AUTO_RETRIES = 3;
|
||||
|
||||
export const ErrorHandler = ({
|
||||
output,
|
||||
cell,
|
||||
refetchOutput,
|
||||
numPreviousTries,
|
||||
}: {
|
||||
output: ModelOutput;
|
||||
cell: ScenarioVariantCell;
|
||||
refetchOutput: () => void;
|
||||
numPreviousTries: number;
|
||||
}) => {
|
||||
const [msToWait, setMsToWait] = useState(0);
|
||||
const shouldAutoRetry =
|
||||
output.errorMessage === rateLimitErrorMessage && numPreviousTries < MAX_AUTO_RETRIES;
|
||||
|
||||
useEffect(() => {
|
||||
if (!shouldAutoRetry) return;
|
||||
if (!cell.retryTime) return;
|
||||
|
||||
const initialWaitTime = calculateDelay(numPreviousTries);
|
||||
const initialWaitTime = cell.retryTime.getTime() - Date.now();
|
||||
const msModuloOneSecond = initialWaitTime % 1000;
|
||||
let remainingTime = initialWaitTime - msModuloOneSecond;
|
||||
setMsToWait(remainingTime);
|
||||
@@ -35,7 +27,6 @@ export const ErrorHandler = ({
|
||||
setMsToWait(remainingTime);
|
||||
|
||||
if (remainingTime <= 0) {
|
||||
refetchOutput();
|
||||
clearInterval(interval);
|
||||
}
|
||||
}, 1000);
|
||||
@@ -45,32 +36,12 @@ export const ErrorHandler = ({
|
||||
clearInterval(interval);
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
}, [shouldAutoRetry, setMsToWait, refetchOutput, numPreviousTries]);
|
||||
}, [cell.retryTime, cell.statusCode, setMsToWait, refetchOutput]);
|
||||
|
||||
return (
|
||||
<VStack w="full">
|
||||
<HStack w="full" alignItems="flex-start" justifyContent="space-between">
|
||||
<Text color="red.600" fontWeight="bold">
|
||||
Error
|
||||
</Text>
|
||||
<Button
|
||||
size="xs"
|
||||
w={4}
|
||||
h={4}
|
||||
px={4}
|
||||
py={4}
|
||||
minW={0}
|
||||
borderRadius={8}
|
||||
variant="ghost"
|
||||
cursor="pointer"
|
||||
onClick={refetchOutput}
|
||||
aria-label="refetch output"
|
||||
>
|
||||
<Icon as={BsArrowClockwise} boxSize={6} />
|
||||
</Button>
|
||||
</HStack>
|
||||
<Text color="red.600" wordBreak="break-word">
|
||||
{output.errorMessage}
|
||||
{cell.errorMessage}
|
||||
</Text>
|
||||
{msToWait > 0 && (
|
||||
<Text color="red.600" fontSize="sm">
|
||||
@@ -80,12 +51,3 @@ export const ErrorHandler = ({
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 5000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import { type RouterOutputs, api } from "~/utils/api";
|
||||
import { api } from "~/utils/api";
|
||||
import { type PromptVariant, type Scenario } from "../types";
|
||||
import { Spinner, Text, Box, Center, Flex } from "@chakra-ui/react";
|
||||
import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
|
||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||
import stringify from "json-stringify-pretty-compact";
|
||||
import { type ReactElement, useState, useEffect, useRef, useCallback } from "react";
|
||||
import { type ReactElement, useState, useEffect } from "react";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { isObject } from "lodash";
|
||||
import useSocket from "~/utils/useSocket";
|
||||
import { OutputStats } from "./OutputStats";
|
||||
import { ErrorHandler } from "./ErrorHandler";
|
||||
import { CellOptions } from "./CellOptions";
|
||||
|
||||
export default function OutputCell({
|
||||
scenario,
|
||||
@@ -37,116 +36,103 @@ export default function OutputCell({
|
||||
// if (variant.config === null || Object.keys(variant.config).length === 0)
|
||||
// disabledReason = "Save your prompt variant to see output";
|
||||
|
||||
const outputMutation = api.outputs.get.useMutation();
|
||||
|
||||
const [output, setOutput] = useState<RouterOutputs["outputs"]["get"]>(null);
|
||||
const [channel, setChannel] = useState<string | undefined>(undefined);
|
||||
const [numPreviousTries, setNumPreviousTries] = useState(0);
|
||||
|
||||
const fetchMutex = useRef(false);
|
||||
const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(
|
||||
async (forceRefetch?: boolean) => {
|
||||
if (fetchMutex.current) return;
|
||||
setNumPreviousTries((prev) => prev + 1);
|
||||
|
||||
fetchMutex.current = true;
|
||||
setOutput(null);
|
||||
|
||||
const shouldStream =
|
||||
isObject(variant) &&
|
||||
"config" in variant &&
|
||||
isObject(variant.config) &&
|
||||
"stream" in variant.config &&
|
||||
variant.config.stream === true;
|
||||
|
||||
const channel = shouldStream ? generateChannel() : undefined;
|
||||
setChannel(channel);
|
||||
|
||||
const output = await outputMutation.mutateAsync({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
channel,
|
||||
forceRefetch,
|
||||
});
|
||||
setOutput(output);
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
fetchMutex.current = false;
|
||||
},
|
||||
[outputMutation, scenario.id, variant.id],
|
||||
const [refetchInterval, setRefetchInterval] = useState(0);
|
||||
const { data: cell, isLoading: queryLoading } = api.scenarioVariantCells.get.useQuery(
|
||||
{ scenarioId: scenario.id, variantId: variant.id },
|
||||
{ refetchInterval },
|
||||
);
|
||||
const hardRefetch = useCallback(() => fetchOutput(true), [fetchOutput]);
|
||||
|
||||
useEffect(fetchOutput, [scenario.id, variant.id]);
|
||||
const { mutateAsync: hardRefetchMutate, isLoading: refetchingOutput } =
|
||||
api.scenarioVariantCells.forceRefetch.useMutation();
|
||||
const [hardRefetch] = useHandledAsyncCallback(async () => {
|
||||
await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id });
|
||||
await utils.scenarioVariantCells.get.invalidate({
|
||||
scenarioId: scenario.id,
|
||||
variantId: variant.id,
|
||||
});
|
||||
}, [hardRefetchMutate, scenario.id, variant.id]);
|
||||
|
||||
const fetchingOutput = queryLoading || refetchingOutput;
|
||||
|
||||
const awaitingOutput =
|
||||
!cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
|
||||
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
|
||||
|
||||
const modelOutput = cell?.modelOutput;
|
||||
|
||||
// Disconnect from socket if we're not streaming anymore
|
||||
const streamedMessage = useSocket(fetchingOutput ? channel : undefined);
|
||||
const streamedMessage = useSocket(cell?.streamingChannel);
|
||||
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
||||
|
||||
if (!vars) return null;
|
||||
|
||||
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
|
||||
|
||||
if (fetchingOutput && !streamedMessage)
|
||||
if (awaitingOutput && !streamedMessage)
|
||||
return (
|
||||
<Center h="100%" w="100%">
|
||||
<Spinner />
|
||||
</Center>
|
||||
);
|
||||
|
||||
if (!output && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
if (!cell && !fetchingOutput) return <Text color="gray.500">Error retrieving output</Text>;
|
||||
|
||||
if (output && output.errorMessage) {
|
||||
return (
|
||||
<ErrorHandler
|
||||
output={output}
|
||||
refetchOutput={hardRefetch}
|
||||
numPreviousTries={numPreviousTries}
|
||||
/>
|
||||
);
|
||||
if (cell && cell.errorMessage) {
|
||||
return <ErrorHandler cell={cell} refetchOutput={hardRefetch} />;
|
||||
}
|
||||
|
||||
const response = output?.output as unknown as ChatCompletion;
|
||||
const response = modelOutput?.output as unknown as ChatCompletion;
|
||||
const message = response?.choices?.[0]?.message;
|
||||
|
||||
if (output && message?.function_call) {
|
||||
if (modelOutput && message?.function_call) {
|
||||
const rawArgs = message.function_call.arguments ?? "null";
|
||||
let parsedArgs: string;
|
||||
try {
|
||||
parsedArgs = JSON.parse(rawArgs);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
} catch (e: any) {
|
||||
parsedArgs = `Failed to parse arguments as JSON: '${rawArgs}' ERROR: ${e.message as string}`;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box fontSize="xs" width="100%" flexWrap="wrap" overflowX="auto">
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset" }}
|
||||
language="json"
|
||||
style={docco}
|
||||
lineProps={{
|
||||
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
|
||||
}}
|
||||
wrapLines
|
||||
>
|
||||
{stringify(
|
||||
{
|
||||
function: message.function_call.name,
|
||||
args: parsedArgs,
|
||||
},
|
||||
{ maxLength: 40 },
|
||||
)}
|
||||
</SyntaxHighlighter>
|
||||
<OutputStats model={variant.model} modelOutput={output} scenario={scenario} />
|
||||
<VStack w="full" spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<SyntaxHighlighter
|
||||
customStyle={{ overflowX: "unset" }}
|
||||
language="json"
|
||||
style={docco}
|
||||
lineProps={{
|
||||
style: { wordBreak: "break-all", whiteSpace: "pre-wrap" },
|
||||
}}
|
||||
wrapLines
|
||||
>
|
||||
{stringify(
|
||||
{
|
||||
function: message.function_call.name,
|
||||
args: parsedArgs,
|
||||
},
|
||||
{ maxLength: 40 },
|
||||
)}
|
||||
</SyntaxHighlighter>
|
||||
</VStack>
|
||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output?.output);
|
||||
const contentToDisplay =
|
||||
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
|
||||
|
||||
return (
|
||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||
{contentToDisplay}
|
||||
{output && <OutputStats model={variant.model} modelOutput={output} scenario={scenario} />}
|
||||
<VStack w="full" alignItems="flex-start" spacing={0}>
|
||||
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
|
||||
<Text>{contentToDisplay}</Text>
|
||||
</VStack>
|
||||
{modelOutput && (
|
||||
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
|
||||
const SHOW_COST = false;
|
||||
const SHOW_TIME = false;
|
||||
const SHOW_COST = true;
|
||||
const SHOW_TIME = true;
|
||||
|
||||
export const OutputStats = ({
|
||||
model,
|
||||
@@ -35,8 +35,6 @@ export const OutputStats = ({
|
||||
|
||||
const cost = promptCost + completionCost;
|
||||
|
||||
if (!evals.length) return null;
|
||||
|
||||
return (
|
||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
|
||||
Reference in New Issue
Block a user