Rename channelId to channel
This commit is contained in:
@@ -9,7 +9,7 @@ import { useMemo, type ReactElement } from "react";
|
|||||||
import { BsClock } from "react-icons/bs";
|
import { BsClock } from "react-icons/bs";
|
||||||
import { type ModelOutput } from "@prisma/client";
|
import { type ModelOutput } from "@prisma/client";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
import { generateChannelId } from "~/server/utils/generateChannelId";
|
import { generateChannel } from "~/utils/generateChannel";
|
||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash";
|
||||||
import useSocket from "~/utils/useSocket";
|
import useSocket from "~/utils/useSocket";
|
||||||
|
|
||||||
@@ -42,22 +42,22 @@ export default function OutputCell({
|
|||||||
isObject(variant.config) &&
|
isObject(variant.config) &&
|
||||||
"stream" in variant.config &&
|
"stream" in variant.config &&
|
||||||
variant.config.stream === true;
|
variant.config.stream === true;
|
||||||
const channelId = useMemo(() => {
|
const channel = useMemo(() => {
|
||||||
if (!shouldStream) return;
|
if (!shouldStream) return;
|
||||||
return generateChannelId();
|
return generateChannel();
|
||||||
}, [shouldStream]);
|
}, [shouldStream]);
|
||||||
|
|
||||||
const output = api.outputs.get.useQuery(
|
const output = api.outputs.get.useQuery(
|
||||||
{
|
{
|
||||||
scenarioId: scenario.id,
|
scenarioId: scenario.id,
|
||||||
variantId: variant.id,
|
variantId: variant.id,
|
||||||
channelId,
|
channel,
|
||||||
},
|
},
|
||||||
{ enabled: disabledReason === null }
|
{ enabled: disabledReason === null }
|
||||||
);
|
);
|
||||||
|
|
||||||
// Disconnect from socket if we're not streaming anymore
|
// Disconnect from socket if we're not streaming anymore
|
||||||
const streamedMessage = useSocket(output.isLoading ? channelId : undefined);
|
const streamedMessage = useSocket(output.isLoading ? channel : undefined);
|
||||||
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
const streamedContent = streamedMessage?.choices?.[0]?.message?.content;
|
||||||
|
|
||||||
if (!vars) return null;
|
if (!vars) return null;
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ env;
|
|||||||
|
|
||||||
export const modelOutputsRouter = createTRPCRouter({
|
export const modelOutputsRouter = createTRPCRouter({
|
||||||
get: publicProcedure
|
get: publicProcedure
|
||||||
.input(z.object({ scenarioId: z.string(), variantId: z.string(), channelId: z.string().optional() }))
|
.input(z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }))
|
||||||
.query(async ({ input }) => {
|
.query(async ({ input }) => {
|
||||||
const existing = await prisma.modelOutput.findUnique({
|
const existing = await prisma.modelOutput.findUnique({
|
||||||
where: {
|
where: {
|
||||||
@@ -64,7 +64,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
timeToComplete: existingResponse.timeToComplete,
|
timeToComplete: existingResponse.timeToComplete,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channelId);
|
modelResponse = await getChatCompletion(filledTemplate, env.OPENAI_API_KEY, input.channel);
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelOutput = await prisma.modelOutput.create({
|
const modelOutput = await prisma.modelOutput.create({
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type CompletionResponse = {
|
|||||||
export async function getChatCompletion(
|
export async function getChatCompletion(
|
||||||
payload: JSONSerializable,
|
payload: JSONSerializable,
|
||||||
apiKey: string,
|
apiKey: string,
|
||||||
channelId?: string,
|
channel?: string,
|
||||||
): Promise<CompletionResponse> {
|
): Promise<CompletionResponse> {
|
||||||
const start = Date.now();
|
const start = Date.now();
|
||||||
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
const response = await fetch("https://api.openai.com/v1/chat/completions", {
|
||||||
@@ -36,13 +36,13 @@ export async function getChatCompletion(
|
|||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (channelId) {
|
if (channel) {
|
||||||
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
|
const completion = streamChatCompletion(payload as unknown as CompletionCreateParams);
|
||||||
let finalOutput: ChatCompletion | null = null;
|
let finalOutput: ChatCompletion | null = null;
|
||||||
await (async () => {
|
await (async () => {
|
||||||
for await (const partialCompletion of completion) {
|
for await (const partialCompletion of completion) {
|
||||||
finalOutput = partialCompletion
|
finalOutput = partialCompletion
|
||||||
wsConnection.emit("message", { channel: channelId, payload: partialCompletion });
|
wsConnection.emit("message", { channel, payload: partialCompletion });
|
||||||
}
|
}
|
||||||
})().catch((err) => console.error(err));
|
})().catch((err) => console.error(err));
|
||||||
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
resp.output = finalOutput as unknown as Prisma.InputJsonValue;
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// generate random channel id
|
// generate random channel id
|
||||||
|
|
||||||
export const generateChannelId = () => {
|
export const generateChannel = () => {
|
||||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||||
};
|
};
|
||||||
@@ -5,7 +5,7 @@ import { env } from "~/env.mjs";
|
|||||||
|
|
||||||
const url = env.NEXT_PUBLIC_SOCKET_URL;
|
const url = env.NEXT_PUBLIC_SOCKET_URL;
|
||||||
|
|
||||||
export default function useSocket(channelId?: string) {
|
export default function useSocket(channel?: string) {
|
||||||
const socketRef = useRef<Socket>();
|
const socketRef = useRef<Socket>();
|
||||||
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
||||||
|
|
||||||
@@ -15,8 +15,8 @@ export default function useSocket(channelId?: string) {
|
|||||||
|
|
||||||
socketRef.current.on("connect", () => {
|
socketRef.current.on("connect", () => {
|
||||||
// Join the specific room
|
// Join the specific room
|
||||||
if (channelId) {
|
if (channel) {
|
||||||
socketRef.current?.emit("join", channelId);
|
socketRef.current?.emit("join", channel);
|
||||||
|
|
||||||
// Listen for 'message' events
|
// Listen for 'message' events
|
||||||
socketRef.current?.on("message", (message: ChatCompletion) => {
|
socketRef.current?.on("message", (message: ChatCompletion) => {
|
||||||
@@ -28,14 +28,14 @@ export default function useSocket(channelId?: string) {
|
|||||||
// Unsubscribe and disconnect on cleanup
|
// Unsubscribe and disconnect on cleanup
|
||||||
return () => {
|
return () => {
|
||||||
if (socketRef.current) {
|
if (socketRef.current) {
|
||||||
if (channelId) {
|
if (channel) {
|
||||||
socketRef.current.off("message");
|
socketRef.current.off("message");
|
||||||
}
|
}
|
||||||
socketRef.current.disconnect();
|
socketRef.current.disconnect();
|
||||||
}
|
}
|
||||||
setMessage(null);
|
setMessage(null);
|
||||||
};
|
};
|
||||||
}, [channelId]);
|
}, [channel]);
|
||||||
|
|
||||||
return message;
|
return message;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user