diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index 892c092..de6dcdb 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -5,7 +5,7 @@ import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import stringify from "json-stringify-pretty-compact"; -import { useMemo, type ReactElement, useState, useEffect } from "react"; +import { type ReactElement, useState, useEffect, useRef } from "react"; import { BsCheck, BsClock, BsX, BsCurrencyDollar } from "react-icons/bs"; import { type ModelOutput } from "@prisma/client"; import { type ChatCompletion } from "openai/resources/chat"; @@ -43,23 +43,26 @@ export default function OutputCell({ const model = getModelName(variant.config as JSONSerializable); - const shouldStream = - isObject(variant) && - "config" in variant && - isObject(variant.config) && - "stream" in variant.config && - variant.config.stream === true; - const channel = useMemo(() => { - if (!shouldStream) return; - return generateChannel(); - }, [shouldStream]); - const outputMutation = api.outputs.get.useMutation(); const [output, setOutput] = useState(null); - + const [channel, setChannel] = useState(undefined); + const fetchMutex = useRef(false); const [fetchOutput, fetchingOutput] = useHandledAsyncCallback(async () => { + if (fetchMutex.current) return; + fetchMutex.current = true; setOutput(null); + + const shouldStream = + isObject(variant) && + "config" in variant && + isObject(variant.config) && + "stream" in variant.config && + variant.config.stream === true; + + const channel = shouldStream ? generateChannel() : undefined; + setChannel(channel); + const output = await outputMutation.mutateAsync({ scenarioId: scenario.id, variantId: variant.id, @@ -67,9 +70,10 @@ export default function OutputCell({ }); setOutput(output); await utils.promptVariants.stats.invalidate(); - }, [outputMutation, scenario.id, variant.id, channel]); + fetchMutex.current = false; + }, [outputMutation, scenario.id, variant.id]); - useEffect(fetchOutput, [scenario.id, variant.id, channel]); + useEffect(fetchOutput, [scenario.id, variant.id]); // Disconnect from socket if we're not streaming anymore const streamedMessage = useSocket(fetchingOutput ? channel : undefined); diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts index 367207e..a82c8c0 100644 --- a/src/utils/hooks.ts +++ b/src/utils/hooks.ts @@ -18,11 +18,11 @@ export function useHandledAsyncCallback( callback: AsyncFunction, deps: React.DependencyList ) { - const [loading, setLoading] = useState(false); + const [loading, setLoading] = useState(0); const [error, setError] = useState(null); const wrappedCallback = useCallback((...args: T) => { - setLoading(true); + setLoading((loading) => loading + 1); setError(null); callback(...args) @@ -31,13 +31,13 @@ export function useHandledAsyncCallback( console.error(error); }) .finally(() => { - setLoading(false); + setLoading((loading) => loading - 1); }); // eslint-disable-next-line react-hooks/exhaustive-deps }, deps); - return [wrappedCallback, loading, error] as const; + return [wrappedCallback, loading > 0, error] as const; } // Have to do this ugly thing to convince Next not to try to access `navigator`