Streaming works for normal text

This commit is contained in:
David Corbitt
2023-07-03 19:51:34 -07:00
parent 2569943ecb
commit 5f11b258ca
9 changed files with 961 additions and 63 deletions

View File

@@ -2,13 +2,16 @@ import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types";
import { Spinner, Text, Box, Center, Flex, Icon } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks";
import { type CreateChatCompletionResponse } from "openai";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement } from "react";
import { useMemo, type ReactElement } from "react";
import { BsClock } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannelId } from "~/server/utils/generateChannelId";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
export default function OutputCell({
scenario,
@@ -33,35 +36,52 @@ export default function OutputCell({
if (variant.config === null || Object.keys(variant.config).length === 0)
disabledReason = "Save your prompt variant to see output";
const shouldStream =
isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channelId = useMemo(() => {
if (!shouldStream) return;
return generateChannelId();
}, [shouldStream]);
const output = api.outputs.get.useQuery(
{
scenarioId: scenario.id,
variantId: variant.id,
channelId,
},
{ enabled: disabledReason === null }
);
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(output.isLoading ? channelId : undefined);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (output.isLoading)
if (output.isLoading && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!output.data) return <Text color="gray.500">Error retrieving output</Text>;
if (!output.data && !output.isLoading)
return <Text color="gray.500">Error retrieving output</Text>;
if (output.data.errorMessage) {
if (output.data && output.data.errorMessage) {
return <Text color="red.600">Error: {output.data.errorMessage}</Text>;
}
const response = output.data?.output as unknown as CreateChatCompletionResponse;
const response = output.data?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (message?.function_call) {
if (output.data && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
@@ -94,10 +114,12 @@ export default function OutputCell({
);
}
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output.data?.output);
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{message?.content ?? JSON.stringify(output.data.output)}
<OutputStats modelOutput={output.data} />
{contentToDisplay}
{output.data && <OutputStats modelOutput={output.data} />}
</Flex>
);
}