import { type ChatCompletion } from "openai/resources/chat";
import { openai } from "../../utils/openai";
import { isAxiosError } from "./utils";
import { type APIResponse } from "openai/core";
import { sleep } from "~/server/utils/sleep";

const MAX_AUTO_RETRIES = 50;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds

/**
 * Computes how long to wait before the next retry attempt.
 *
 * The base delay grows exponentially from MIN_DELAY and is capped at MAX_DELAY.
 * A random jitter of up to the base delay is added, so the returned wait lies in
 * [base, 2 * base).
 *
 * @param numPreviousTries - number of attempts that have already failed
 * @returns delay in milliseconds
 */
function calculateDelay(numPreviousTries: number): number {
  const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
  const jitter = Math.random() * baseDelay;
  return baseDelay + jitter;
}

/**
 * Invokes `getCompletion` until it succeeds or MAX_AUTO_RETRIES attempts have failed.
 *
 * Every failure is logged and, unless it was the last attempt, followed by a
 * backoff delay from `calculateDelay`.
 *
 * @param getCompletion - thunk that issues one completion request
 * @returns the first successful completion, or `undefined` if every attempt failed
 */
const getCompletionWithBackoff = async (
  getCompletion: () => Promise<APIResponse<ChatCompletion>>,
) => {
  for (let tries = 0; tries < MAX_AUTO_RETRIES; tries++) {
    try {
      return await getCompletion();
    } catch (e) {
      if (isAxiosError(e)) {
        console.error(e?.response?.data?.error?.message);
      } else {
        console.error(e);
      }
      // Back off after every failure, whatever its type. Previously Axios errors
      // were retried immediately. There is no point waiting after the final attempt.
      if (tries < MAX_AUTO_RETRIES - 1) {
        await sleep(calculateDelay(tries));
      }
    }
  }
  console.error(`Giving up after ${MAX_AUTO_RETRIES} failed completion attempts`);
  return undefined;
};

// TODO: Add seeds to ensure batches don't contain duplicate data
const MAX_BATCH_SIZE = 5;

type DatasetRow = { input: string; output: string };

/**
 * Splits `total` into consecutive batch sizes of at most `maxBatchSize`.
 * Every batch is full except possibly the last, which holds the remainder
 * (e.g. 12 with max 5 -> [5, 5, 2]). A total of 0 yields no batches.
 */
const splitIntoBatches = (total: number, maxBatchSize: number): number[] =>
  Array.from({ length: Math.ceil(total / maxBatchSize) }, (_, i) =>
    Math.min(maxBatchSize, total - i * maxBatchSize),
  );

/** Returns true when `value` is an object with string `input` and `output` fields. */
const isDatasetRow = (value: unknown): value is DatasetRow => {
  if (typeof value !== "object" || value === null) return false;
  const { input, output } = value as Record<string, unknown>;
  return typeof input === "string" && typeof output === "string";
};

/**
 * Extracts the generated rows from a completion's `add_list_of_data` function-call
 * arguments.
 *
 * Returns an empty list, and logs the problem where useful, when:
 * - the completion is missing (all retries failed),
 * - the arguments are not valid JSON, or
 * - the JSON lacks a `rows` array.
 *
 * This way one bad batch does not discard the others. Individual rows without
 * string `input`/`output` are dropped.
 */
const parseRows = (completion: ChatCompletion | undefined): DatasetRow[] => {
  const args = completion?.choices[0]?.message?.function_call?.arguments;
  if (!args) return [];

  let parsed: unknown;
  try {
    parsed = JSON.parse(args);
  } catch (e) {
    console.error("Failed to parse generated rows", e);
    return [];
  }

  if (typeof parsed !== "object" || parsed === null) return [];
  const rows = (parsed as { rows?: unknown }).rows;
  if (!Array.isArray(rows)) return [];
  return rows.filter(isDatasetRow);
};

/**
 * Asks GPT-4 to generate `numToGenerate` input/output pairs.
 *
 * Requests are sent in parallel batches of at most MAX_BATCH_SIZE rows. Each batch
 * is retried with backoff via `getCompletionWithBackoff`. A batch that ultimately
 * fails or returns unparseable data contributes no rows, so the result may contain
 * fewer than `numToGenerate` entries.
 *
 * @param numToGenerate - total number of rows requested
 * @param inputDescription - requirements the generated inputs should satisfy
 * @param outputDescription - requirements the generated outputs should satisfy
 * @returns the generated rows from all successful batches
 */
export const autogenerateDatasetEntries = async (
  numToGenerate: number,
  inputDescription: string,
  outputDescription: string,
): Promise<{ input: string; output: string }[]> => {
  const batchSizes = splitIntoBatches(numToGenerate, MAX_BATCH_SIZE);

  const getCompletion = (batchSize: number) =>
    openai.chat.completions.create({
      model: "gpt-4",
      messages: [
        {
          role: "system",
          content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
        },
      ],
      functions: [
        {
          name: "add_list_of_data",
          description: "Add a list of data to the database",
          parameters: {
            type: "object",
            properties: {
              rows: {
                type: "array",
                description: "The rows of data that match the description",
                items: {
                  type: "object",
                  properties: {
                    input: {
                      type: "string",
                      description: "The input for this row",
                    },
                    output: {
                      type: "string",
                      description: "The output for this row",
                    },
                  },
                  // Ask the model to always emit both fields for every row.
                  required: ["input", "output"],
                },
              },
            },
            required: ["rows"],
          },
        },
      ],
      // Force the model to answer via the function so the output is structured JSON.
      function_call: { name: "add_list_of_data" },
      temperature: 0.5,
    });

  const completions = await Promise.all(
    batchSizes.map((batchSize) => getCompletionWithBackoff(() => getCompletion(batchSize))),
  );

  return completions.flatMap((completion) => parseRows(completion));
};