Requeue rate-limited query model tasks (#99)
* Continue polling stats until all evals complete * Return evaluation changes early, before it has run * Add task for running new eval * requeue rate-limited tasks * Fix prettier
This commit is contained in:
@@ -21,17 +21,14 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
completionTokens: 0,
|
completionTokens: 0,
|
||||||
scenarioCount: 0,
|
scenarioCount: 0,
|
||||||
outputCount: 0,
|
outputCount: 0,
|
||||||
awaitingRetrievals: false,
|
awaitingEvals: false,
|
||||||
},
|
},
|
||||||
refetchInterval,
|
refetchInterval,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
// Poll every two seconds while we are waiting for LLM retrievals to finish
|
||||||
useEffect(
|
useEffect(() => setRefetchInterval(data.awaitingEvals ? 5000 : 0), [data.awaitingEvals]);
|
||||||
() => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0),
|
|
||||||
[data.awaitingRetrievals],
|
|
||||||
);
|
|
||||||
|
|
||||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||||
"green.500",
|
"green.500",
|
||||||
@@ -69,7 +66,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
{data.overallCost && !data.awaitingRetrievals && (
|
{data.overallCost && (
|
||||||
<CostTooltip
|
<CostTooltip
|
||||||
promptTokens={data.promptTokens}
|
promptTokens={data.promptTokens}
|
||||||
completionTokens={data.completionTokens}
|
completionTokens={data.completionTokens}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { runAllEvals } from "~/server/utils/evaluations";
|
import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
|
||||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||||
|
|
||||||
export const evaluationsRouter = createTRPCRouter({
|
export const evaluationsRouter = createTRPCRouter({
|
||||||
@@ -40,9 +40,7 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
await queueRunNewEval(input.experimentId);
|
||||||
// to kick off a background job or something instead
|
|
||||||
await runAllEvals(input.experimentId);
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
update: protectedProcedure
|
update: protectedProcedure
|
||||||
@@ -78,7 +76,7 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
// Re-run all evals. Other eval results will already be cached, so this
|
// Re-run all evals. Other eval results will already be cached, so this
|
||||||
// should only re-run the updated one.
|
// should only re-run the updated one.
|
||||||
await runAllEvals(evaluation.experimentId);
|
await queueRunNewEval(experimentId);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
delete: protectedProcedure
|
delete: protectedProcedure
|
||||||
|
|||||||
@@ -130,16 +130,9 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||||
|
|
||||||
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
|
const awaitingEvals = !!evalResults.find(
|
||||||
where: {
|
(result) => result.totalCount < scenarioCount * evals.length,
|
||||||
promptVariantId: input.variantId,
|
);
|
||||||
testScenario: { visible: true },
|
|
||||||
// Check if is PENDING or IN_PROGRESS
|
|
||||||
retrievalStatus: {
|
|
||||||
in: ["PENDING", "IN_PROGRESS"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
evalResults,
|
evalResults,
|
||||||
@@ -148,7 +141,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
overallCost: overallTokens._sum?.cost ?? 0,
|
overallCost: overallTokens._sum?.cost ?? 0,
|
||||||
scenarioCount,
|
scenarioCount,
|
||||||
outputCount,
|
outputCount,
|
||||||
awaitingRetrievals,
|
awaitingEvals,
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ function defineTask<TPayload>(
|
|||||||
taskIdentifier: string,
|
taskIdentifier: string,
|
||||||
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
|
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
|
||||||
) {
|
) {
|
||||||
const enqueue = async (payload: TPayload) => {
|
const enqueue = async (payload: TPayload, runAt?: Date) => {
|
||||||
console.log("Enqueuing task", taskIdentifier, payload);
|
console.log("Enqueuing task", taskIdentifier, payload);
|
||||||
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload);
|
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt });
|
||||||
};
|
};
|
||||||
|
|
||||||
const handler = (payload: TPayload, helpers: Helpers) => {
|
const handler = (payload: TPayload, helpers: Helpers) => {
|
||||||
|
|||||||
@@ -6,15 +6,15 @@ import { wsConnection } from "~/utils/wsConnection";
|
|||||||
import { runEvalsForOutput } from "../utils/evaluations";
|
import { runEvalsForOutput } from "../utils/evaluations";
|
||||||
import hashPrompt from "../utils/hashPrompt";
|
import hashPrompt from "../utils/hashPrompt";
|
||||||
import parseConstructFn from "../utils/parseConstructFn";
|
import parseConstructFn from "../utils/parseConstructFn";
|
||||||
import { sleep } from "../utils/sleep";
|
|
||||||
import defineTask from "./defineTask";
|
import defineTask from "./defineTask";
|
||||||
|
|
||||||
export type QueryModelJob = {
|
export type QueryModelJob = {
|
||||||
cellId: string;
|
cellId: string;
|
||||||
stream: boolean;
|
stream: boolean;
|
||||||
|
numPreviousTries: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
const MAX_AUTO_RETRIES = 10;
|
const MAX_AUTO_RETRIES = 50;
|
||||||
const MIN_DELAY = 500; // milliseconds
|
const MIN_DELAY = 500; // milliseconds
|
||||||
const MAX_DELAY = 15000; // milliseconds
|
const MAX_DELAY = 15000; // milliseconds
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ function calculateDelay(numPreviousTries: number): number {
|
|||||||
|
|
||||||
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
||||||
console.log("RUNNING TASK", task);
|
console.log("RUNNING TASK", task);
|
||||||
const { cellId, stream } = task;
|
const { cellId, stream, numPreviousTries } = task;
|
||||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||||
where: { id: cellId },
|
where: { id: cellId },
|
||||||
include: { modelResponses: true },
|
include: { modelResponses: true },
|
||||||
@@ -98,62 +98,72 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
|||||||
|
|
||||||
const inputHash = hashPrompt(prompt);
|
const inputHash = hashPrompt(prompt);
|
||||||
|
|
||||||
for (let i = 0; true; i++) {
|
let modelResponse = await prisma.modelResponse.create({
|
||||||
let modelResponse = await prisma.modelResponse.create({
|
data: {
|
||||||
|
inputHash,
|
||||||
|
scenarioVariantCellId: cellId,
|
||||||
|
requestedAt: new Date(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||||
|
if (response.type === "success") {
|
||||||
|
modelResponse = await prisma.modelResponse.update({
|
||||||
|
where: { id: modelResponse.id },
|
||||||
data: {
|
data: {
|
||||||
inputHash,
|
output: response.value as Prisma.InputJsonObject,
|
||||||
scenarioVariantCellId: cellId,
|
statusCode: response.statusCode,
|
||||||
requestedAt: new Date(),
|
receivedAt: new Date(),
|
||||||
|
promptTokens: response.promptTokens,
|
||||||
|
completionTokens: response.completionTokens,
|
||||||
|
cost: response.cost,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
|
||||||
if (response.type === "success") {
|
|
||||||
modelResponse = await prisma.modelResponse.update({
|
|
||||||
where: { id: modelResponse.id },
|
|
||||||
data: {
|
|
||||||
output: response.value as Prisma.InputJsonObject,
|
|
||||||
statusCode: response.statusCode,
|
|
||||||
receivedAt: new Date(),
|
|
||||||
promptTokens: response.promptTokens,
|
|
||||||
completionTokens: response.completionTokens,
|
|
||||||
cost: response.cost,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
|
await prisma.scenarioVariantCell.update({
|
||||||
|
where: { id: cellId },
|
||||||
|
data: {
|
||||||
|
retrievalStatus: "COMPLETE",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||||
|
} else {
|
||||||
|
const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES;
|
||||||
|
const delay = calculateDelay(numPreviousTries);
|
||||||
|
const retryTime = new Date(Date.now() + delay);
|
||||||
|
|
||||||
|
await prisma.modelResponse.update({
|
||||||
|
where: { id: modelResponse.id },
|
||||||
|
data: {
|
||||||
|
statusCode: response.statusCode,
|
||||||
|
errorMessage: response.message,
|
||||||
|
receivedAt: new Date(),
|
||||||
|
retryTime: shouldRetry ? retryTime : null,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (shouldRetry) {
|
||||||
|
await queryModel.enqueue(
|
||||||
|
{
|
||||||
|
cellId,
|
||||||
|
stream,
|
||||||
|
numPreviousTries: numPreviousTries + 1,
|
||||||
|
},
|
||||||
|
retryTime,
|
||||||
|
);
|
||||||
await prisma.scenarioVariantCell.update({
|
await prisma.scenarioVariantCell.update({
|
||||||
where: { id: cellId },
|
where: { id: cellId },
|
||||||
data: {
|
data: {
|
||||||
retrievalStatus: "COMPLETE",
|
retrievalStatus: "PENDING",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
|
||||||
break;
|
|
||||||
} else {
|
} else {
|
||||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
await prisma.scenarioVariantCell.update({
|
||||||
const delay = calculateDelay(i);
|
where: { id: cellId },
|
||||||
|
|
||||||
await prisma.modelResponse.update({
|
|
||||||
where: { id: modelResponse.id },
|
|
||||||
data: {
|
data: {
|
||||||
statusCode: response.statusCode,
|
retrievalStatus: "ERROR",
|
||||||
errorMessage: response.message,
|
|
||||||
receivedAt: new Date(),
|
|
||||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (shouldRetry) {
|
|
||||||
await sleep(delay);
|
|
||||||
} else {
|
|
||||||
await prisma.scenarioVariantCell.update({
|
|
||||||
where: { id: cellId },
|
|
||||||
data: {
|
|
||||||
retrievalStatus: "ERROR",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -170,6 +180,6 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
|||||||
jobQueuedAt: new Date(),
|
jobQueuedAt: new Date(),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
queryModel.enqueue({ cellId, stream }),
|
queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }),
|
||||||
]);
|
]);
|
||||||
};
|
};
|
||||||
|
|||||||
17
src/server/tasks/runNewEval.task.ts
Normal file
17
src/server/tasks/runNewEval.task.ts
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import { runAllEvals } from "../utils/evaluations";
|
||||||
|
import defineTask from "./defineTask";
|
||||||
|
|
||||||
|
export type RunNewEvalJob = {
|
||||||
|
experimentId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
// When a new eval is created, we want to run it on all existing outputs, but return the new eval first
|
||||||
|
export const runNewEval = defineTask<RunNewEvalJob>("runNewEval", async (task) => {
|
||||||
|
console.log("RUNNING TASK", task);
|
||||||
|
const { experimentId } = task;
|
||||||
|
await runAllEvals(experimentId);
|
||||||
|
});
|
||||||
|
|
||||||
|
export const queueRunNewEval = async (experimentId: string) => {
|
||||||
|
await runNewEval.enqueue({ experimentId });
|
||||||
|
};
|
||||||
@@ -3,10 +3,11 @@ import "dotenv/config";
|
|||||||
|
|
||||||
import { env } from "~/env.mjs";
|
import { env } from "~/env.mjs";
|
||||||
import { queryModel } from "./queryModel.task";
|
import { queryModel } from "./queryModel.task";
|
||||||
|
import { runNewEval } from "./runNewEval.task";
|
||||||
|
|
||||||
console.log("Starting worker");
|
console.log("Starting worker");
|
||||||
|
|
||||||
const registeredTasks = [queryModel];
|
const registeredTasks = [queryModel, runNewEval];
|
||||||
|
|
||||||
const taskList = registeredTasks.reduce((acc, task) => {
|
const taskList = registeredTasks.reduce((acc, task) => {
|
||||||
acc[task.task.identifier] = task.task.handler;
|
acc[task.task.identifier] = task.task.handler;
|
||||||
@@ -16,7 +17,7 @@ const taskList = registeredTasks.reduce((acc, task) => {
|
|||||||
// Run a worker to execute jobs:
|
// Run a worker to execute jobs:
|
||||||
const runner = await run({
|
const runner = await run({
|
||||||
connectionString: env.DATABASE_URL,
|
connectionString: env.DATABASE_URL,
|
||||||
concurrency: 20,
|
concurrency: 50,
|
||||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||||
noHandleSignals: false,
|
noHandleSignals: false,
|
||||||
pollInterval: 1000,
|
pollInterval: 1000,
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ export const runEvalsForOutput = async (
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Will not run eval-output pairs that already exist in the database
|
||||||
export const runAllEvals = async (experimentId: string) => {
|
export const runAllEvals = async (experimentId: string) => {
|
||||||
const outputs = await prisma.modelResponse.findMany({
|
const outputs = await prisma.modelResponse.findMany({
|
||||||
where: {
|
where: {
|
||||||
|
|||||||
Reference in New Issue
Block a user