109 lines
3.3 KiB
TypeScript
109 lines
3.3 KiB
TypeScript
import { type ChatCompletion } from "openai/resources/chat";
|
|
import { openai } from "../../utils/openai";
|
|
import { isAxiosError } from "./utils";
|
|
import { type APIResponse } from "openai/core";
|
|
import { sleep } from "~/server/utils/sleep";
|
|
|
|
/** Maximum number of attempts the retry loop makes before giving up. */
const MAX_AUTO_RETRIES = 50;

/** Base backoff used for the very first retry, in milliseconds. */
const MIN_DELAY = 500;

/** Ceiling applied to the exponential base backoff (before jitter), in milliseconds. */
const MAX_DELAY = 15 * 1000;
|
|
|
|
/**
 * Computes how long to wait before the next attempt, in milliseconds.
 *
 * The base delay doubles with every previous try, starting at MIN_DELAY and
 * capped at MAX_DELAY. A random jitter in the range [0, base) is then added on
 * top, so the returned value lies in [base, 2 * base) and may exceed MAX_DELAY.
 *
 * @param numPreviousTries Number of attempts that have already been made.
 * @returns The delay to wait, in milliseconds.
 */
function calculateDelay(numPreviousTries: number): number {
  const exponential = MIN_DELAY * 2 ** numPreviousTries;
  const cappedBase = Math.min(exponential, MAX_DELAY);
  const randomExtra = Math.random() * cappedBase;
  return cappedBase + randomExtra;
}
|
|
|
|
/**
 * Runs `getCompletion`, retrying with exponential backoff when it throws.
 *
 * At most MAX_AUTO_RETRIES attempts are made. Every failure is logged, and the
 * loop then waits `calculateDelay(tries)` milliseconds before the next attempt.
 *
 * @param getCompletion Factory that issues one completion request per call.
 * @returns The first successful completion, or `undefined` if every attempt failed.
 */
const getCompletionWithBackoff = async (
  getCompletion: () => Promise<APIResponse<ChatCompletion>>,
) => {
  for (let tries = 0; tries < MAX_AUTO_RETRIES; tries++) {
    try {
      return await getCompletion();
    } catch (e) {
      if (isAxiosError(e)) {
        console.error(e?.response?.data?.error?.message);
      } else {
        console.error(e);
      }

      // Back off after every kind of failure. Previously only non-Axios errors
      // slept, so HTTP-level failures were retried in a tight loop.
      // No point in sleeping once the last attempt has been used up.
      const hasAttemptsLeft = tries < MAX_AUTO_RETRIES - 1;
      if (hasAttemptsLeft) {
        await sleep(calculateDelay(tries));
      }
    }
  }

  // All attempts failed; callers are expected to handle the missing completion.
  return undefined;
};
|
|
/**
 * Largest number of rows requested from the model in a single completion call;
 * bigger requests are split into several batches of at most this size.
 */
// TODO: Add seeds to ensure batches don't contain duplicate data
const MAX_BATCH_SIZE = 5;
|
|
|
|
/** Runtime check that a parsed JSON value has string `input` and `output` fields. */
const isGeneratedRow = (row: unknown): row is { input: string; output: string } => {
  if (typeof row !== "object" || row === null) return false;
  const { input, output } = row as { input?: unknown; output?: unknown };
  return typeof input === "string" && typeof output === "string";
};

/**
 * Asks the model to invent dataset rows matching the given descriptions.
 *
 * The request is split into batches of at most MAX_BATCH_SIZE rows. Each batch
 * is sent as its own chat completion (with retry/backoff) and all batches run
 * concurrently. Batches that fail, return no function call, or return
 * arguments that are not valid JSON contribute no rows instead of aborting the
 * whole call. Rows lacking a string `input` or `output` are dropped.
 *
 * @param numToGenerate Total number of rows requested.
 * @param inputDescription Requirements each generated input should follow.
 * @param outputDescription Requirements each generated output should follow.
 * @returns The generated rows from every successful batch, concatenated in batch order.
 */
export const autogenerateDatasetEntries = async (
  numToGenerate: number,
  inputDescription: string,
  outputDescription: string,
): Promise<{ input: string; output: string }[]> => {
  // Full batches of MAX_BATCH_SIZE, with the final batch holding the remainder.
  const batchCount = Math.ceil(numToGenerate / MAX_BATCH_SIZE);
  const batchSizes = Array.from({ length: batchCount }, (_, i) =>
    Math.min(MAX_BATCH_SIZE, numToGenerate - i * MAX_BATCH_SIZE),
  );

  const getCompletion = (batchSize: number) =>
    openai.chat.completions.create({
      model: "gpt-4",
      messages: [
        {
          role: "system",
          content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
        },
      ],
      functions: [
        {
          name: "add_list_of_data",
          description: "Add a list of data to the database",
          parameters: {
            type: "object",
            properties: {
              rows: {
                type: "array",
                description: "The rows of data that match the description",
                items: {
                  type: "object",
                  properties: {
                    input: {
                      type: "string",
                      description: "The input for this row",
                    },
                    output: {
                      type: "string",
                      description: "The output for this row",
                    },
                  },
                },
              },
            },
          },
        },
      ],

      // Force the model to answer through the function so the result is structured.
      function_call: { name: "add_list_of_data" },
      temperature: 0.5,
    });

  const completions = await Promise.all(
    batchSizes.map((batchSize) => getCompletionWithBackoff(() => getCompletion(batchSize))),
  );

  return completions.flatMap((completion) => {
    const rawArguments = completion?.choices[0]?.message?.function_call?.arguments;
    // A batch that exhausted its retries or produced no function call yields no rows.
    // (The previous fallback string "{rows: []}" was not valid JSON and made JSON.parse throw.)
    if (!rawArguments) return [];

    let parsed: unknown;
    try {
      parsed = JSON.parse(rawArguments);
    } catch (e) {
      // Model output can be truncated or malformed; skip this batch rather than
      // discarding the rows from every other batch.
      console.error("Failed to parse generated dataset rows", e);
      return [];
    }

    const rows = (parsed as { rows?: unknown } | null)?.rows;
    return Array.isArray(rows) ? rows.filter(isGeneratedRow) : [];
  });
};
|