Requeue rate-limited query model tasks (#99)

* Continue polling stats until all evals complete

* Return evaluation changes early, before the evaluation has run

* Add task for running new eval

* Requeue rate-limited tasks

* Fix Prettier formatting
This commit is contained in:
arcticfly
2023-07-26 16:30:50 -07:00
committed by GitHub
parent 807665fdc1
commit 26b6fa4f0c
8 changed files with 90 additions and 73 deletions

View File

@@ -6,15 +6,15 @@ import { wsConnection } from "~/utils/wsConnection";
import { runEvalsForOutput } from "../utils/evaluations";
import hashPrompt from "../utils/hashPrompt";
import parseConstructFn from "../utils/parseConstructFn";
import { sleep } from "../utils/sleep";
import defineTask from "./defineTask";
/**
 * Payload carried by a single `queryModel` task.
 */
export type QueryModelJob = {
// ID of the scenarioVariantCell row that the task looks up and updates.
cellId: string;
// Passed through unchanged when the task re-enqueues itself.
// NOTE(review): presumably selects streaming of partial output — confirm against the task body.
stream: boolean;
// Number of attempts already made for this cell: 0 on the initial enqueue,
// incremented by one on each automatic requeue, and compared against
// MAX_AUTO_RETRIES to decide whether another retry is allowed.
numPreviousTries: number;
};
const MAX_AUTO_RETRIES = 10;
const MAX_AUTO_RETRIES = 50;
const MIN_DELAY = 500; // milliseconds
const MAX_DELAY = 15000; // milliseconds
@@ -26,7 +26,7 @@ function calculateDelay(numPreviousTries: number): number {
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
console.log("RUNNING TASK", task);
const { cellId, stream } = task;
const { cellId, stream, numPreviousTries } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: cellId },
include: { modelResponses: true },
@@ -98,62 +98,72 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
const inputHash = hashPrompt(prompt);
for (let i = 0; true; i++) {
let modelResponse = await prisma.modelResponse.create({
let modelResponse = await prisma.modelResponse.create({
data: {
inputHash,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
inputHash,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
output: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode,
receivedAt: new Date(),
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
cost: response.cost,
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
output: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode,
receivedAt: new Date(),
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
cost: response.cost,
},
});
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "COMPLETE",
},
});
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
} else {
const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES;
const delay = calculateDelay(numPreviousTries);
const retryTime = new Date(Date.now() + delay);
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
statusCode: response.statusCode,
errorMessage: response.message,
receivedAt: new Date(),
retryTime: shouldRetry ? retryTime : null,
},
});
if (shouldRetry) {
await queryModel.enqueue(
{
cellId,
stream,
numPreviousTries: numPreviousTries + 1,
},
retryTime,
);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "COMPLETE",
retrievalStatus: "PENDING",
},
});
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
break;
} else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i);
await prisma.modelResponse.update({
where: { id: modelResponse.id },
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: response.statusCode,
errorMessage: response.message,
receivedAt: new Date(),
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: "ERROR",
},
});
if (shouldRetry) {
await sleep(delay);
} else {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "ERROR",
},
});
break;
}
}
}
});
@@ -170,6 +180,6 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
jobQueuedAt: new Date(),
},
}),
queryModel.enqueue({ cellId, stream }),
queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }),
]);
};