From 4131aa67d0480ce87446293777d300a36ffddee8 Mon Sep 17 00:00:00 2001
From: arcticfly <41524992+arcticfly@users.noreply.github.com>
Date: Mon, 17 Jul 2023 18:04:38 -0700
Subject: [PATCH] Continue polling VariantStats while LLM retrieval in
progress, minor UI fixes (#54)
* Prevent zoom in on iOS
* Expand function return code background to fill cell
* Keep OutputStats on far right of cells
* Continue polling prompt stats while cells are retrieving from LLM
* Add comment to _document.tsx
* Fix prettier
---
.../OutputsTable/OutputCell/OutputCell.tsx | 26 +++++++++++-----
.../OutputsTable/OutputCell/OutputStats.tsx | 2 +-
src/components/OutputsTable/VariantStats.tsx | 30 ++++++++++++++-----
src/components/tooltip/CostTooltip.tsx | 1 -
src/pages/_document.tsx | 23 ++++++++++++++
.../api/routers/promptVariants.router.ts | 21 ++++++++++++-
6 files changed, 85 insertions(+), 18 deletions(-)
create mode 100644 src/pages/_document.tsx
diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx
index 2e1585a..315db45 100644
--- a/src/components/OutputsTable/OutputCell/OutputCell.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx
@@ -1,6 +1,6 @@
import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types";
-import { Spinner, Text, Box, Center, Flex, VStack } from "@chakra-ui/react";
+import { Spinner, Text, Center, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
@@ -55,7 +55,10 @@ export default function OutputCell({
const fetchingOutput = queryLoading || refetchingOutput;
const awaitingOutput =
- !cell || cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "IN_PROGRESS";
+ !cell ||
+ cell.retrievalStatus === "PENDING" ||
+ cell.retrievalStatus === "IN_PROGRESS" ||
+ refetchingOutput;
useEffect(() => setRefetchInterval(awaitingOutput ? 1000 : 0), [awaitingOutput]);
const modelOutput = cell?.modelOutput;
@@ -95,11 +98,18 @@ export default function OutputCell({
}
return (
-
-
+
+
-
+
);
}
@@ -125,7 +135,7 @@ export default function OutputCell({
message?.content ?? streamedContent ?? JSON.stringify(modelOutput?.output);
return (
-
+
{contentToDisplay}
@@ -133,6 +143,6 @@ export default function OutputCell({
{modelOutput && (
)}
-
+
);
}
diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx
index 3f511df..7d2fa99 100644
--- a/src/components/OutputsTable/OutputCell/OutputStats.tsx
+++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx
@@ -31,7 +31,7 @@ export const OutputStats = ({
const cost = promptCost + completionCost;
return (
-
+
{modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluation.result > 0.5;
diff --git a/src/components/OutputsTable/VariantStats.tsx b/src/components/OutputsTable/VariantStats.tsx
index 9bd3c66..52e0f11 100644
--- a/src/components/OutputsTable/VariantStats.tsx
+++ b/src/components/OutputsTable/VariantStats.tsx
@@ -1,12 +1,14 @@
-import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
+import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types";
import { cellPadding } from "../constants";
import { api } from "~/utils/api";
import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";
import { CostTooltip } from "../tooltip/CostTooltip";
+import { useEffect, useState } from "react";
export default function VariantStats(props: { variant: PromptVariant }) {
+ const [refetchInterval, setRefetchInterval] = useState(0);
const { data } = api.promptVariants.stats.useQuery(
{
variantId: props.variant.id,
@@ -19,10 +21,18 @@ export default function VariantStats(props: { variant: PromptVariant }) {
completionTokens: 0,
scenarioCount: 0,
outputCount: 0,
+ awaitingRetrievals: false,
},
+ refetchInterval,
},
);
+ // Poll every two seconds while we are waiting for LLM retrievals to finish
+ useEffect(
+ () => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
+ [data.awaitingRetrievals],
+ );
+
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
"gray.500",
@@ -33,16 +43,20 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
- if (!(data.evalResults.length > 0) && !data.overallCost) return null;
-
return (
-
+
{showNumFinished && (
{data.outputCount} / {data.scenarioCount}
)}
-
+
{data.evalResults.map((result) => {
const passedFrac = result.passCount / result.totalCount;
return (
@@ -55,17 +69,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
);
})}
- {data.overallCost && (
+ {data.overallCost && !data.awaitingRetrievals ? (
-
+
{data.overallCost.toFixed(3)}
+ ) : (
+
)}
);
diff --git a/src/components/tooltip/CostTooltip.tsx b/src/components/tooltip/CostTooltip.tsx
index 0d76752..68cf3ea 100644
--- a/src/components/tooltip/CostTooltip.tsx
+++ b/src/components/tooltip/CostTooltip.tsx
@@ -20,7 +20,6 @@ export const CostTooltip = ({
color="gray.800"
bgColor="gray.50"
borderWidth={1}
- py={2}
hasArrow
shouldWrapChildren
label={
diff --git a/src/pages/_document.tsx b/src/pages/_document.tsx
new file mode 100644
index 0000000..4c5f82d
--- /dev/null
+++ b/src/pages/_document.tsx
@@ -0,0 +1,23 @@
+import Document, { Html, Head, Main, NextScript } from "next/document";
+
+class MyDocument extends Document {
+ render() {
+ return (
+
+
+ {/* Prevent automatic zoom-in on iPhone when focusing on text input */}
+
+
+
+
+
+
+
+ );
+ }
+}
+
+export default MyDocument;
diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts
index 20f50d0..b9ea6dd 100644
--- a/src/server/api/routers/promptVariants.router.ts
+++ b/src/server/api/routers/promptVariants.router.ts
@@ -109,7 +109,26 @@ export const promptVariantsRouter = createTRPCRouter({
const overallCost = overallPromptCost + overallCompletionCost;
- return { evalResults, promptTokens, completionTokens, overallCost, scenarioCount, outputCount };
+ const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
+ where: {
+ promptVariantId: input.variantId,
+ testScenario: { visible: true },
+ // Check if is PENDING or IN_PROGRESS
+ retrievalStatus: {
+ in: ["PENDING", "IN_PROGRESS"],
+ },
+ },
+ }));
+
+ return {
+ evalResults,
+ promptTokens,
+ completionTokens,
+ overallCost,
+ scenarioCount,
+ outputCount,
+ awaitingRetrievals,
+ };
}),
create: publicProcedure