move app to app/ subdir
This commit is contained in:
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { isAxiosError } from "./utils";
|
||||
import { type APIResponse } from "openai/core";
|
||||
import { sleep } from "~/server/utils/sleep";
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithBackoff = async (
|
||||
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||
) => {
|
||||
let completion;
|
||||
let tries = 0;
|
||||
while (tries < MAX_AUTO_RETRIES) {
|
||||
try {
|
||||
completion = await getCompletion();
|
||||
break;
|
||||
} catch (e) {
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
await sleep(calculateDelay(tries));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
tries++;
|
||||
}
|
||||
return completion;
|
||||
};
|
||||
// TODO: Add seeds to ensure batches don't contain duplicate data
|
||||
const MAX_BATCH_SIZE = 5;
|
||||
|
||||
export const autogenerateDatasetEntries = async (
|
||||
numToGenerate: number,
|
||||
inputDescription: string,
|
||||
outputDescription: string,
|
||||
): Promise<{ input: string; output: string }[]> => {
|
||||
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE
|
||||
? numToGenerate % MAX_BATCH_SIZE
|
||||
: MAX_BATCH_SIZE,
|
||||
);
|
||||
|
||||
const getCompletion = (batchSize: number) =>
|
||||
openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "add_list_of_data",
|
||||
description: "Add a list of data to the database",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
rows: {
|
||||
type: "array",
|
||||
description: "The rows of data that match the description",
|
||||
items: {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "The input for this row",
|
||||
},
|
||||
output: {
|
||||
type: "string",
|
||||
description: "The output for this row",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_list_of_data" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||
);
|
||||
|
||||
const completions = await Promise.all(completionCallbacks);
|
||||
|
||||
const rows = completions.flatMap((completion) => {
|
||||
const parsed = JSON.parse(
|
||||
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||
) as { rows: { input: string; output: string }[] };
|
||||
return parsed.rows;
|
||||
});
|
||||
|
||||
return rows;
|
||||
};
|
||||
Reference in New Issue
Block a user