From 4131aa67d0480ce87446293777d300a36ffddee8 Mon Sep 17 00:00:00 2001 From: arcticfly <41524992+arcticfly@users.noreply.github.com> Date: Mon, 17 Jul 2023 18:04:38 -0700 Subject: [PATCH] Continue polling VariantStats while LLM retrieval in progress, minor UI fixes (#54) * Prevent zoom in on iOS * Expand function return code background to fill cell * Keep OutputStats on far right of cells * Continue polling prompt stats while cells are retrieving from LLM * Add comment to _document.tsx * Fix prettier --- .../OutputsTable/OutputCell/OutputCell.tsx | 26 +++++++++++----- .../OutputsTable/OutputCell/OutputStats.tsx | 2 +- src/components/OutputsTable/VariantStats.tsx | 30 ++++++++++++++----- src/components/tooltip/CostTooltip.tsx | 1 - src/pages/_document.tsx | 23 ++++++++++++++ .../api/routers/promptVariants.router.ts | 21 ++++++++++++- 6 files changed, 85 insertions(+), 18 deletions(-) create mode 100644 src/pages/_document.tsx diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index 2e1585a..315db45 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -1,6 +1,6 @@ import { api } from "~/utils/api"; import { type PromptVariant, type Scenario } from "../types"; -import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react"; +import { Spinner, Text, Center, VStack } from "@chakra-ui/react"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import SyntaxHighlighter from "react-syntax-highlighter"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; @@ -55,7 +55,10 @@ export default function OutputCell({ const fetchingOutput = queryLoading || refetchingOutput; const awaitingOutput = - !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS"; + !cell || + cell.retrievalStatus === "PENDING" || + cell.retrievalStatus === "IN_PROGRESS" || + refetchingOutput; useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]); const modelOutput = cell?.modelOutput; @@ -95,11 +98,18 @@ export default function OutputCell({ } return ( - - + + - + ); } @@ -125,7 +135,7 @@ export default function OutputCell({ message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output); return ( - + {contentToDisplay} @@ -133,6 +143,6 @@ export default function OutputCell({ {modelOutput && ( )} - + ); } diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx index 3f511df..7d2fa99 100644 --- a/src/components/OutputsTable/OutputCell/OutputStats.tsx +++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx @@ -31,7 +31,7 @@ export const OutputStats = ({ const cost = promptCost + completionCost; return ( - + {modelOutput.outputEvaluation.map((evaluation) => { const passed = evaluation.result > 0.5; diff --git a/src/components/OutputsTable/VariantStats.tsx b/src/components/OutputsTable/VariantStats.tsx index 9bd3c66..52e0f11 100644 --- a/src/components/OutputsTable/VariantStats.tsx +++ b/src/components/OutputsTable/VariantStats.tsx @@ -1,12 +1,14 @@ -import { HStack, Icon, Text, useToken } from "@chakra-ui/react"; +import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react"; import { type PromptVariant } from "./types"; import { cellPadding } from "../constants"; import { api } from "~/utils/api"; import chroma from "chroma-js"; import { BsCurrencyDollar } from "react-icons/bs"; import { CostTooltip } from "../tooltip/CostTooltip"; +import { useEffect, useState } from "react"; export default function VariantStats(props: { variant: PromptVariant }) { + const [refetchInterval, setRefetchInterval] = useState(0); const { data } = api.promptVariants.stats.useQuery( { variantId: props.variant.id, @@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) { completionTokens: 0, scenarioCount: 0, outputCount: 0, + awaitingRetrievals: false, }, + refetchInterval, }, ); + // Poll every two seconds while we are waiting for LLM retrievals to finish + useEffect( + () => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0), + [data.awaitingRetrievals], + ); + const [passColor, neutralColor, failColor] = useToken("colors", [ "green.500", "gray.500", @@ -33,16 +43,20 @@ export default function VariantStats(props: { variant: PromptVariant }) { const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount; - if (!(data.evalResults.length > 0) && !data.overallCost) return null; - return ( - + {showNumFinished && ( {data.outputCount} / {data.scenarioCount} )} - + {data.evalResults.map((result) => { const passedFrac = result.passCount / result.totalCount; return ( @@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) { ); })} - {data.overallCost && ( + {data.overallCost && !data.awaitingRetrievals ? ( - + {data.overallCost.toFixed(3)} + ) : ( + )} ); diff --git a/src/components/tooltip/CostTooltip.tsx b/src/components/tooltip/CostTooltip.tsx index 0d76752..68cf3ea 100644 --- a/src/components/tooltip/CostTooltip.tsx +++ b/src/components/tooltip/CostTooltip.tsx @@ -20,7 +20,6 @@ export const CostTooltip = ({ color="gray.800" bgColor="gray.50" borderWidth={1} - py={2} hasArrow shouldWrapChildren label={ diff --git a/src/pages/_document.tsx b/src/pages/_document.tsx new file mode 100644 index 0000000..4c5f82d --- /dev/null +++ b/src/pages/_document.tsx @@ -0,0 +1,23 @@ +import Document, { Html, Head, Main, NextScript } from "next/document"; + +class MyDocument extends Document { + render() { + return ( + + + {/* Prevent automatic zoom-in on iPhone when focusing on text input */} + + + +
+ + + + ); + } +} + +export default MyDocument; diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index 20f50d0..b9ea6dd 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -109,7 +109,26 @@ export const promptVariantsRouter = createTRPCRouter({ const overallCost = overallPromptCost + overallCompletionCost; - return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount }; + const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({ + where: { + promptVariantId: input.variantId, + testScenario: { visible: true }, + // Check if is PENDING or IN_PROGRESS + retrievalStatus: { + in: ["PENDING", "IN_PROGRESS"], + }, + }, + })); + + return { + evalResults, + promptTokens, + completionTokens, + overallCost, + scenarioCount, + outputCount, + awaitingRetrievals, + }; }), create: publicProcedure