diff --git a/src/components/OutputsTable/OutputCell.tsx b/src/components/OutputsTable/OutputCell.tsx index e1c434f..a40521c 100644 --- a/src/components/OutputsTable/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell.tsx @@ -9,7 +9,7 @@ import { useMemo, type ReactElement } from "react"; import { BsClock } from "react-icons/bs"; import { type ModelOutput } from "@prisma/client"; import { type ChatCompletion } from "openai/resources/chat"; -import { generateChannelId } from "~/server/utils/generateChannelId"; +import { generateChannel } from "~/utils/generateChannel"; import { isObject } from "lodash"; import useSocket from "~/utils/useSocket"; @@ -42,22 +42,22 @@ export default function OutputCell({ isObject(variant.config) && "stream" in variant.config && variant.config.stream === true; - const channelId = useMemo(() => { + const channel = useMemo(() => { if (!shouldStream) return; - return generateChannelId(); + return generateChannel(); }, [shouldStream]); const output = api.outputs.get.useQuery( { scenarioId: scenario.id, variantId: variant.id, - channelId, + channel, }, { enabled: disabledReason === null } ); // Disconnect from socket if we're not streaming anymore - const streamedMessage = useSocket(output.isLoading ? channelId : undefined); + const streamedMessage = useSocket(output.isLoading ? channel : undefined); const streamedContent = streamedMessage?.choices?.[0]?.message?.content; if (!vars) return null; diff --git a/src/server/api/routers/modelOutputs.router.ts b/src/server/api/routers/modelOutputs.router.ts index cb7f528..d144747 100644 --- a/src/server/api/routers/modelOutputs.router.ts +++ b/src/server/api/routers/modelOutputs.router.ts @@ -12,7 +12,7 @@ env; export const modelOutputsRouter = createTRPCRouter({ get: publicProcedure - .input(z.object({ scenarioId: z.string(), variantId: z.string(), channelId: z.string().optional() })) + .input(z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })) .query(async ({ input }) => { const existing = await prisma.modelOutput.findUnique({ where: { @@ -64,7 +64,7 @@ export const modelOutputsRouter = createTRPCRouter({ timeToComplete: existingResponse.timeToComplete, }; } else { - modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channelId); + modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channel); } const modelOutput = await prisma.modelOutput.create({ diff --git a/src/server/utils/getChatCompletion.ts b/src/server/utils/getChatCompletion.ts index 90affe1..3669a9d 100644 --- a/src/server/utils/getChatCompletion.ts +++ b/src/server/utils/getChatCompletion.ts @@ -16,7 +16,7 @@ type CompletionResponse = { export async function getChatCompletion( payload: JSONSerializable, apiKey: string, - channelId?: string, + channel?: string, ): Promise { const start = Date.now(); const response = await fetch("https://api.openai.com/v1/chat/completions", { @@ -36,13 +36,13 @@ export async function getChatCompletion( }; try { - if (channelId) { + if (channel) { const completion = streamChatCompletion(payload as unknown as CompletionCreateParams); let finalOutput: ChatCompletion | null = null; await (async () => { for await (const partialCompletion of completion) { finalOutput = partialCompletion - wsConnection.emit("message", { channel: channelId, payload: partialCompletion }); + wsConnection.emit("message", { channel, payload: partialCompletion }); } })().catch((err) => console.error(err)); resp.output = finalOutput as unknown as Prisma.InputJsonValue; diff --git a/src/server/utils/generateChannelId.ts b/src/utils/generateChannel.ts similarity index 76% rename from src/server/utils/generateChannelId.ts rename to src/utils/generateChannel.ts index 273951f..c7e7c6b 100644 --- a/src/server/utils/generateChannelId.ts +++ b/src/utils/generateChannel.ts @@ -1,5 +1,5 @@ // generate random channel id -export const generateChannelId = () => { +export const generateChannel = () => { return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15); }; diff --git a/src/utils/useSocket.ts b/src/utils/useSocket.ts index 1ab7307..9b66307 100644 --- a/src/utils/useSocket.ts +++ b/src/utils/useSocket.ts @@ -5,7 +5,7 @@ import { env } from "~/env.mjs"; const url = env.NEXT_PUBLIC_SOCKET_URL; -export default function useSocket(channelId?: string) { +export default function useSocket(channel?: string) { const socketRef = useRef(); const [message, setMessage] = useState(null); @@ -15,8 +15,8 @@ export default function useSocket(channelId?: string) { socketRef.current.on("connect", () => { // Join the specific room - if (channelId) { - socketRef.current?.emit("join", channelId); + if (channel) { + socketRef.current?.emit("join", channel); // Listen for 'message' events socketRef.current?.on("message", (message: ChatCompletion) => { @@ -28,14 +28,14 @@ export default function useSocket(channelId?: string) { // Unsubscribe and disconnect on cleanup return () => { if (socketRef.current) { - if (channelId) { + if (channel) { socketRef.current.off("message"); } socketRef.current.disconnect(); } setMessage(null); }; - }, [channelId]); + }, [channel]); return message; }