Update autogen.ts
autogen.ts

@@ -1,4 +1,5 @@
-import { type CreateChatCompletionRequest } from "openai";
+import { type CompletionCreateParams } from "openai/resources/chat";
 import { prisma } from "../db";
 import { openai } from "../utils/openai";
 import { pick } from "lodash";
@@ -62,7 +63,7 @@ export const autogenerateScenarioValues = async (
 
   if (!experiment || !(variables?.length > 0) || !prompt) return {};
 
-  const messages: CreateChatCompletionRequest["messages"] = [
+  const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming["messages"] = [
     {
       role: "system",
       content:
@@ -90,7 +91,6 @@ export const autogenerateScenarioValues = async (
     .forEach((vals) => {
      messages.push({
        role: "assistant",
-        // @ts-expect-error the openai type definition is wrong, the content field is required
        content: null,
        function_call: {
          name: "add_scenario",
@@ -105,7 +105,7 @@ export const autogenerateScenarioValues = async (
  }, {} as Record<string, { type: "string" }>);

  try {
-    const completion = await openai.createChatCompletion({
+    const completion = await openai.chat.completions.create({
      model: "gpt-3.5-turbo-0613",
      messages,
      functions: [
@@ -123,7 +123,7 @@ export const autogenerateScenarioValues = async (
    });

    const parsed = JSON.parse(
-      completion.data.choices[0]?.message?.function_call?.arguments ?? "{}"
+      completion.choices[0]?.message?.function_call?.arguments ?? "{}"
    ) as Record<string, string>;
    return parsed;
  } catch (e) {
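For reference, a minimal sketch of the openai v4 call pattern that the updated autogen.ts code relies on. This is not part of the commit; the model and prompt below are illustrative only. In v4 the client is constructed with new OpenAI(...), chat.completions.create replaces createChatCompletion, and the response is the completion object itself rather than an axios envelope, so completion.data.choices becomes completion.choices.

// Sketch only: illustrates the v3 -> v4 response shape change assumed by the diff above.
import OpenAI from "openai";

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function example() {
  // v4 returns the completion directly; there is no .data wrapper as in v3.
  const completion = await client.chat.completions.create({
    model: "gpt-3.5-turbo-0613",
    messages: [{ role: "user", content: "Generate one test scenario." }],
  });
  // function_call.arguments (when present) arrives as a JSON string.
  const args = completion.choices[0]?.message?.function_call?.arguments ?? "{}";
  return JSON.parse(args) as Record<string, string>;
}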
utils/openai.ts

@@ -1,8 +1,59 @@
-import { Configuration, OpenAIApi } from "openai";
+import { omit } from "lodash";
 import { env } from "~/env.mjs";
 
-const configuration = new Configuration({
-  apiKey: env.OPENAI_API_KEY,
-});
+import OpenAI from "openai";
+import { type ChatCompletion, type ChatCompletionChunk, type CompletionCreateParams } from "openai/resources/chat";
 
-export const openai = new OpenAIApi(configuration);
+// console.log("creating openai client");
+
+export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY });
+
+export const mergeStreamedChunks = (
+  base: ChatCompletion | null,
+  chunk: ChatCompletionChunk
+): ChatCompletion => {
+  if (base === null) {
+    return mergeStreamedChunks({ ...chunk, choices: [] }, chunk);
+  }
+
+  const choices = [...base.choices];
+  for (const choice of chunk.choices) {
+    const baseChoice = choices.find((c) => c.index === choice.index);
+    if (baseChoice) {
+      baseChoice.finish_reason = choice.finish_reason ?? baseChoice.finish_reason;
+      baseChoice.message = baseChoice.message ?? { role: "assistant" };
+
+      if (choice.delta?.content)
+        baseChoice.message.content =
+          (baseChoice.message.content as string ?? "") + (choice.delta.content ?? "");
+      if (choice.delta?.function_call) {
+        const fnCall = baseChoice.message.function_call ?? {};
+        fnCall.name = (fnCall.name as string ?? "") + (choice.delta.function_call.name as string ?? "");
+        fnCall.arguments = (fnCall.arguments as string ?? "") + (choice.delta.function_call.arguments as string ?? "");
+      }
+    } else {
+      choices.push({ ...omit(choice, "delta"), message: { role: "assistant", ...choice.delta } });
+    }
+  }
+
+  const merged: ChatCompletion = {
+    ...base,
+    choices,
+  };
+
+  return merged;
+};
+
+export const streamChatCompletion = async function* (body: CompletionCreateParams) {
+  // eslint-disable-next-line @typescript-eslint/no-unsafe-call
+  const resp = await openai.chat.completions.create({
+    ...body,
+    stream: true,
+  });
+
+  let mergedChunks: ChatCompletion | null = null;
+  for await (const part of resp) {
+    mergedChunks = mergeStreamedChunks(mergedChunks, part);
+    yield mergedChunks;
+  }
+};
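A hypothetical consumer of the new streamChatCompletion helper might look like the sketch below. It is not part of this commit; the import path, model, and prompt are illustrative. The generator yields a progressively merged ChatCompletion on every streamed chunk, so the last yielded value is the complete response.

// Usage sketch under the assumptions above.
import { streamChatCompletion } from "../utils/openai";

async function demo() {
  for await (const partial of streamChatCompletion({
    model: "gpt-3.5-turbo-0613",
    messages: [{ role: "user", content: "Say hello." }],
  })) {
    // Each yield is the full ChatCompletion merged so far; the final yield is the complete one.
    console.log(partial.choices[0]?.message?.content);
  }
}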