Streaming works for normal text

This commit is contained in:
David Corbitt
2023-07-03 19:51:34 -07:00
parent 2569943ecb
commit 5f11b258ca
9 changed files with 961 additions and 63 deletions

View File

@@ -5,7 +5,9 @@
"license": "Apache-2.0",
"scripts": {
"build": "next build",
"dev": "next dev",
"dev:next": "next dev",
"dev:wss": "pnpm tsx --watch src/wss-server.ts",
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
"postinstall": "prisma generate",
"lint": "next lint",
"start": "next start",
@@ -27,28 +29,35 @@
"@trpc/next": "^10.26.0",
"@trpc/react-query": "^10.26.0",
"@trpc/server": "^10.26.0",
"concurrently": "^8.2.0",
"cors": "^2.8.5",
"dayjs": "^1.11.8",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"framer-motion": "^10.12.17",
"json-stringify-pretty-compact": "^4.0.0",
"lodash": "^4.17.21",
"next": "^13.4.2",
"next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1",
"openai": "^3.3.0",
"openai": "4.0.0-beta.2",
"posthog-js": "^1.68.4",
"react": "18.2.0",
"react-dom": "18.2.0",
"react-icons": "^4.10.1",
"react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.5.0",
"socket.io": "^4.7.1",
"socket.io-client": "^4.7.1",
"superjson": "1.12.2",
"tsx": "^3.12.7",
"zod": "^3.21.4"
},
"devDependencies": {
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
"@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0",
"@types/express": "^4.17.17",
"@types/lodash": "^4.14.195",
"@types/node": "^18.16.0",
"@types/react": "^18.2.6",

856
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -2,13 +2,16 @@ import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "./types";
import { Spinner, Text, Box, Center, Flex, Icon } from "@chakra-ui/react";
import { useExperiment } from "~/utils/hooks";
import { type CreateChatCompletionResponse } from "openai";
import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact";
import { type ReactElement } from "react";
import { useMemo, type ReactElement } from "react";
import { BsClock } from "react-icons/bs";
import { type ModelOutput } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { generateChannelId } from "~/server/utils/generateChannelId";
import { isObject } from "lodash";
import useSocket from "~/utils/useSocket";
export default function OutputCell({
scenario,
@@ -33,35 +36,52 @@ export default function OutputCell({
if (variant.config === null || Object.keys(variant.config).length === 0)
disabledReason = "Save your prompt variant to see output";
const shouldStream =
isObject(variant) &&
"config" in variant &&
isObject(variant.config) &&
"stream" in variant.config &&
variant.config.stream === true;
const channelId = useMemo(() => {
if (!shouldStream) return;
return generateChannelId();
}, [shouldStream]);
const output = api.outputs.get.useQuery(
{
scenarioId: scenario.id,
variantId: variant.id,
channelId,
},
{ enabled: disabledReason === null }
);
// Disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket(output.isLoading ? channelId : undefined);
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
if (!vars) return null;
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
if (output.isLoading)
if (output.isLoading && !streamedMessage)
return (
<Center h="100%" w="100%">
<Spinner />
</Center>
);
if (!output.data) return <Text color="gray.500">Error retrieving output</Text>;
if (!output.data && !output.isLoading)
return <Text color="gray.500">Error retrieving output</Text>;
if (output.data.errorMessage) {
if (output.data && output.data.errorMessage) {
return <Text color="red.600">Error: {output.data.errorMessage}</Text>;
}
const response = output.data?.output as unknown as CreateChatCompletionResponse;
const response = output.data?.output as unknown as ChatCompletion;
const message = response?.choices?.[0]?.message;
if (message?.function_call) {
if (output.data && message?.function_call) {
const rawArgs = message.function_call.arguments ?? "null";
let parsedArgs: string;
try {
@@ -94,10 +114,12 @@ export default function OutputCell({
);
}
const contentToDisplay = message?.content ?? streamedContent ?? JSON.stringify(output.data?.output);
return (
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
{message?.content ?? JSON.stringify(output.data.output)}
<OutputStats modelOutput={output.data} />
{contentToDisplay}
{output.data && <OutputStats modelOutput={output.data} />}
</Flex>
);
}

View File

@@ -12,7 +12,7 @@ env;
export const modelOutputsRouter = createTRPCRouter({
get: publicProcedure
.input(z.object({ scenarioId: z.string(), variantId: z.string() }))
.input(z.object({ scenarioId: z.string(), variantId: z.string(), channelId: z.string().optional() }))
.query(async ({ input }) => {
const existing = await prisma.modelOutput.findUnique({
where: {
@@ -64,7 +64,7 @@ export const modelOutputsRouter = createTRPCRouter({
timeToComplete: existingResponse.timeToComplete,
};
} else {
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY);
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channelId);
}
const modelOutput = await prisma.modelOutput.create({

View File

@@ -0,0 +1,5 @@
// Build a pseudo-random channel identifier from two base-36 fragments
// (up to 26 lowercase alphanumeric characters total).
// NOTE(review): Math.random is not cryptographically secure — fine for
// ephemeral pub/sub channel names, but do not reuse for auth tokens.
export const generateChannelId = () => {
  const fragment = () => Math.random().toString(36).substring(2, 15);
  return fragment() + fragment();
};

View File

@@ -1,6 +1,10 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash";
import { type JSONSerializable } from "../types";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
@@ -11,7 +15,8 @@ type CompletionResponse = {
export async function getChatCompletion(
payload: JSONSerializable,
apiKey: string
apiKey: string,
channelId?: string,
): Promise<CompletionResponse> {
const start = Date.now();
const response = await fetch("https://api.openai.com/v1/chat/completions", {
@@ -31,8 +36,21 @@ export async function getChatCompletion(
};
try {
resp.timeToComplete = Date.now() - start;
resp.output = await response.json();
if (channelId) {
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
let finalOutput: ChatCompletion | null = null;
await (async () => {
for await (const partialCompletion of completion) {
finalOutput = partialCompletion
wsConnection.emit("message", { channel: channelId, payload: partialCompletion });
}
})().catch((err) => console.error(err));
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
resp.timeToComplete = Date.now() - start;
} else {
resp.timeToComplete = Date.now() - start;
resp.output = await response.json();
}
if (!response.ok) {
// If it's an object, try to get the error message

42
src/utils/useSocket.ts Normal file
View File

@@ -0,0 +1,42 @@
import { type ChatCompletion } from "openai/resources/chat";
import { useRef, useState, useEffect } from "react";
import { io, type Socket } from "socket.io-client";
import { env } from "~/env.mjs";
const url = env.NEXT_PUBLIC_SOCKET_URL;
/**
 * Subscribes to streamed chat-completion messages on a websocket channel.
 *
 * Returns the most recent `ChatCompletion` payload received on `channelId`,
 * or `null` when no channel is given or nothing has arrived yet.
 *
 * Fixes over the previous version:
 * - No connection is opened at all when `channelId` is undefined (the old
 *   code called `io(url)` unconditionally and then never listened).
 * - The "message" listener is registered once per effect run instead of
 *   inside the "connect" handler, which stacked a duplicate listener on
 *   every automatic reconnect.
 * - Removed leftover debug `console.log`.
 */
export default function useSocket(channelId?: string) {
  const socketRef = useRef<Socket>();
  const [message, setMessage] = useState<ChatCompletion | null>(null);

  useEffect(() => {
    // Nothing to subscribe to — don't open a connection.
    if (!channelId) return;

    const socket = io(url);
    socketRef.current = socket;

    // Single listener for the lifetime of this effect.
    socket.on("message", (incoming: ChatCompletion) => {
      setMessage(incoming);
    });

    // (Re-)join the room on every (re)connect so messages keep arriving
    // after a dropped connection.
    socket.on("connect", () => {
      socket.emit("join", channelId);
    });

    // Unsubscribe, disconnect, and clear stale state on cleanup.
    return () => {
      socket.off("message");
      socket.disconnect();
      setMessage(null);
    };
  }, [channelId]);

  return message;
}

View File

@@ -0,0 +1,4 @@
import { io } from "socket.io-client";
import { env } from "~/env.mjs";

// Module-level socket.io client connected at import time. Server-side code
// uses it to publish streamed completion chunks to the websocket relay.
const relayUrl = env.NEXT_PUBLIC_SOCKET_URL;
export const wsConnection = io(relayUrl);

36
src/wss-server.ts Normal file
View File

@@ -0,0 +1,36 @@
import "dotenv/config";
import express from "express";
import { createServer } from "http";
import { Server } from "socket.io";
import { env } from "./env.mjs";
import cors from "cors";
// Standalone websocket relay: clients "join" a room (channel), and any
// "message" published to a channel is fanned out to the room's other members.

// Derive the port from NEXT_PUBLIC_SOCKET_URL (e.g. http://localhost:3318).
// Number(...) normalizes the string slice; NaN or a missing port falls back
// to 3318 (the old code could pass a non-numeric string to listen()).
const port = Number(env.NEXT_PUBLIC_SOCKET_URL?.split(":")[2]) || 3318;

const app = express();
app.use(cors());
const server = createServer(app);

const io = new Server(server, {
  // Any origin may connect — the server only relays ephemeral streaming
  // payloads. NOTE(review): tighten this if the relay ever carries
  // sensitive data.
  cors: {
    origin: "*",
    methods: ["GET", "POST"],
  },
});

io.on("connection", (socket) => {
  // Add this socket to the requested room; join() may return a promise
  // depending on the adapter, hence the optional chaining.
  socket.on("join", (room: string) => {
    socket.join(room)?.catch((err) => console.log(err));
  });

  // Relay a published message to every other socket in the target room.
  // `unknown` instead of `any`: the payload is forwarded opaquely.
  socket.on("message", (msg: { channel: string; payload: unknown }) => {
    socket.to(msg.channel).emit("message", msg.payload);
  });
});

server.listen(port, () => {
  console.log(`listening on *:${port}`);
});