Requeue rate-limited query model tasks (#99)
* Continue polling stats until all evals complete * Return evaluation changes early, before it has run * Add task for running new eval * requeue rate-limited tasks * Fix prettier
This commit is contained in:
@@ -21,17 +21,14 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
completionTokens: 0,
|
||||
scenarioCount: 0,
|
||||
outputCount: 0,
|
||||
awaitingRetrievals: false,
|
||||
awaitingEvals: false,
|
||||
},
|
||||
refetchInterval,
|
||||
},
|
||||
);
|
||||
|
||||
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||
useEffect(
|
||||
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
|
||||
[data.awaitingRetrievals],
|
||||
);
|
||||
useEffect(() => setRefetchInterval(data.awaitingEvals ? 5000 : 0), [data.awaitingEvals]);
|
||||
|
||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||
"green.500",
|
||||
@@ -69,7 +66,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
);
|
||||
})}
|
||||
</HStack>
|
||||
{data.overallCost && !data.awaitingRetrievals && (
|
||||
{data.overallCost && (
|
||||
<CostTooltip
|
||||
promptTokens={data.promptTokens}
|
||||
completionTokens={data.completionTokens}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
@@ -40,9 +40,7 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
||||
// to kick off a background job or something instead
|
||||
await runAllEvals(input.experimentId);
|
||||
await queueRunNewEval(input.experimentId);
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
@@ -78,7 +76,7 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await runAllEvals(evaluation.experimentId);
|
||||
await queueRunNewEval(experimentId);
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
|
||||
@@ -130,16 +130,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
|
||||
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
// Check if is PENDING or IN_PROGRESS
|
||||
retrievalStatus: {
|
||||
in: ["PENDING", "IN_PROGRESS"],
|
||||
},
|
||||
},
|
||||
}));
|
||||
const awaitingEvals = !!evalResults.find(
|
||||
(result) => result.totalCount < scenarioCount * evals.length,
|
||||
);
|
||||
|
||||
return {
|
||||
evalResults,
|
||||
@@ -148,7 +141,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
overallCost: overallTokens._sum?.cost ?? 0,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingRetrievals,
|
||||
awaitingEvals,
|
||||
};
|
||||
}),
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ function defineTask<TPayload>(
|
||||
taskIdentifier: string,
|
||||
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
|
||||
) {
|
||||
const enqueue = async (payload: TPayload) => {
|
||||
const enqueue = async (payload: TPayload, runAt?: Date) => {
|
||||
console.log("Enqueuing task", taskIdentifier, payload);
|
||||
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
|
||||
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt });
|
||||
};
|
||||
|
||||
const handler = (payload: TPayload, helpers: Helpers) => {
|
||||
|
||||
@@ -6,15 +6,15 @@ import { wsConnection } from "~/utils/wsConnection";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import hashPrompt from "../utils/hashPrompt";
|
||||
import parseConstructFn from "../utils/parseConstructFn";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import defineTask from "./defineTask";
|
||||
|
||||
export type QueryModelJob = {
|
||||
cellId: string;
|
||||
stream: boolean;
|
||||
numPreviousTries: number;
|
||||
};
|
||||
|
||||
const MAX_AUTO_RETRIES = 10;
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
@@ -26,7 +26,7 @@ function calculateDelay(numPreviousTries: number): number {
|
||||
|
||||
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
||||
console.log("RUNNING TASK", task);
|
||||
const { cellId, stream } = task;
|
||||
const { cellId, stream, numPreviousTries } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: cellId },
|
||||
include: { modelResponses: true },
|
||||
@@ -98,62 +98,72 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
let modelResponse = await prisma.modelResponse.create({
|
||||
let modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
requestedAt: new Date(),
|
||||
},
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
modelResponse = await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
requestedAt: new Date(),
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
statusCode: response.statusCode,
|
||||
receivedAt: new Date(),
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
cost: response.cost,
|
||||
},
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
modelResponse = await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
statusCode: response.statusCode,
|
||||
receivedAt: new Date(),
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
cost: response.cost,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES;
|
||||
const delay = calculateDelay(numPreviousTries);
|
||||
const retryTime = new Date(Date.now() + delay);
|
||||
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
errorMessage: response.message,
|
||||
receivedAt: new Date(),
|
||||
retryTime: shouldRetry ? retryTime : null,
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRetry) {
|
||||
await queryModel.enqueue(
|
||||
{
|
||||
cellId,
|
||||
stream,
|
||||
numPreviousTries: numPreviousTries + 1,
|
||||
},
|
||||
retryTime,
|
||||
);
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
retrievalStatus: "PENDING",
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||
break;
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||
const delay = calculateDelay(i);
|
||||
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
errorMessage: response.message,
|
||||
receivedAt: new Date(),
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRetry) {
|
||||
await sleep(delay);
|
||||
} else {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -170,6 +180,6 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||
jobQueuedAt: new Date(),
|
||||
},
|
||||
}),
|
||||
queryModel.enqueue({ cellId, stream }),
|
||||
queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }),
|
||||
]);
|
||||
};
|
||||
|
||||
17
src/server/tasks/runNewEval.task.ts
Normal file
17
src/server/tasks/runNewEval.task.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { runAllEvals } from "../utils/evaluations";
|
||||
import defineTask from "./defineTask";
|
||||
|
||||
export type RunNewEvalJob = {
|
||||
experimentId: string;
|
||||
};
|
||||
|
||||
// When a new eval is created, we want to run it on all existing outputs, but return the new eval first
|
||||
export const runNewEval = defineTask<RunNewEvalJob>("runNewEval", async (task) => {
|
||||
console.log("RUNNING TASK", task);
|
||||
const { experimentId } = task;
|
||||
await runAllEvals(experimentId);
|
||||
});
|
||||
|
||||
export const queueRunNewEval = async (experimentId: string) => {
|
||||
await runNewEval.enqueue({ experimentId });
|
||||
};
|
||||
@@ -3,10 +3,11 @@ import "dotenv/config";
|
||||
|
||||
import { env } from "~/env.mjs";
|
||||
import { queryModel } from "./queryModel.task";
|
||||
import { runNewEval } from "./runNewEval.task";
|
||||
|
||||
console.log("Starting worker");
|
||||
|
||||
const registeredTasks = [queryModel];
|
||||
const registeredTasks = [queryModel, runNewEval];
|
||||
|
||||
const taskList = registeredTasks.reduce((acc, task) => {
|
||||
acc[task.task.identifier] = task.task.handler;
|
||||
@@ -16,7 +17,7 @@ const taskList = registeredTasks.reduce((acc, task) => {
|
||||
// Run a worker to execute jobs:
|
||||
const runner = await run({
|
||||
connectionString: env.DATABASE_URL,
|
||||
concurrency: 20,
|
||||
concurrency: 50,
|
||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||
noHandleSignals: false,
|
||||
pollInterval: 1000,
|
||||
|
||||
@@ -46,6 +46,7 @@ export const runEvalsForOutput = async (
|
||||
);
|
||||
};
|
||||
|
||||
// Will not run eval-output pairs that already exist in the database
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelResponse.findMany({
|
||||
where: {
|
||||
|
||||
Reference in New Issue
Block a user