diff --git a/src/components/OutputsTable/VariantStats.tsx b/src/components/OutputsTable/VariantStats.tsx index aa7198e..08ebfb2 100644 --- a/src/components/OutputsTable/VariantStats.tsx +++ b/src/components/OutputsTable/VariantStats.tsx @@ -21,17 +21,14 @@ export default function VariantStats(props: { variant: PromptVariant }) { completionTokens: 0, scenarioCount: 0, outputCount: 0, - awaitingRetrievals: false, + awaitingEvals: false, }, refetchInterval, }, ); // Poll every two seconds while we are waiting for LLM retrievals to finish - useEffect( - () => setRefetchInterval(data.awaitingRetrievals ? 2000 : 0), - [data.awaitingRetrievals], - ); + useEffect(() => setRefetchInterval(data.awaitingEvals ? 5000 : 0), [data.awaitingEvals]); const [passColor, neutralColor, failColor] = useToken("colors", [ "green.500", @@ -69,7 +66,7 @@ export default function VariantStats(props: { variant: PromptVariant }) { ); })} - {data.overallCost && !data.awaitingRetrievals && ( + {data.overallCost && ( result.totalCount < scenarioCount * evals.length, + ); return { evalResults, @@ -148,7 +141,7 @@ export const promptVariantsRouter = createTRPCRouter({ overallCost: overallTokens._sum?.cost ?? 0, scenarioCount, outputCount, - awaitingRetrievals, + awaitingEvals, }; }), diff --git a/src/server/tasks/defineTask.ts b/src/server/tasks/defineTask.ts index 64bd834..503c6c5 100644 --- a/src/server/tasks/defineTask.ts +++ b/src/server/tasks/defineTask.ts @@ -7,9 +7,9 @@ function defineTask( taskIdentifier: string, taskHandler: (payload: TPayload, helpers: Helpers) => Promise, ) { - const enqueue = async (payload: TPayload) => { + const enqueue = async (payload: TPayload, runAt?: Date) => { console.log("Enqueuing task", taskIdentifier, payload); - await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload); + await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt }); }; const handler = (payload: TPayload, helpers: Helpers) => { diff --git a/src/server/tasks/queryModel.task.ts b/src/server/tasks/queryModel.task.ts index 929eb29..cd91487 100644 --- a/src/server/tasks/queryModel.task.ts +++ b/src/server/tasks/queryModel.task.ts @@ -6,15 +6,15 @@ import { wsConnection } from "~/utils/wsConnection"; import { runEvalsForOutput } from "../utils/evaluations"; import hashPrompt from "../utils/hashPrompt"; import parseConstructFn from "../utils/parseConstructFn"; -import { sleep } from "../utils/sleep"; import defineTask from "./defineTask"; export type QueryModelJob = { cellId: string; stream: boolean; + numPreviousTries: number; }; -const MAX_AUTO_RETRIES = 10; +const MAX_AUTO_RETRIES = 50; const MIN_DELAY = 500; // milliseconds const MAX_DELAY = 15000; // milliseconds @@ -26,7 +26,7 @@ function calculateDelay(numPreviousTries: number): number { export const queryModel = defineTask("queryModel", async (task) => { console.log("RUNNING TASK", task); - const { cellId, stream } = task; + const { cellId, stream, numPreviousTries } = task; const cell = await prisma.scenarioVariantCell.findUnique({ where: { id: cellId }, include: { modelResponses: true }, @@ -98,62 +98,72 @@ export const queryModel = defineTask("queryModel", async (task) = const inputHash = hashPrompt(prompt); - for (let i = 0; true; i++) { - let modelResponse = await prisma.modelResponse.create({ + let modelResponse = await prisma.modelResponse.create({ + data: { + inputHash, + scenarioVariantCellId: cellId, + requestedAt: new Date(), + }, + }); + const response = await provider.getCompletion(prompt.modelInput, onStream); + if (response.type === "success") { + modelResponse = await prisma.modelResponse.update({ + where: { id: modelResponse.id }, data: { - inputHash, - scenarioVariantCellId: cellId, - requestedAt: new Date(), + output: response.value as Prisma.InputJsonObject, + statusCode: response.statusCode, + receivedAt: new Date(), + promptTokens: response.promptTokens, + completionTokens: response.completionTokens, + cost: response.cost, }, }); - const response = await provider.getCompletion(prompt.modelInput, onStream); - if (response.type === "success") { - modelResponse = await prisma.modelResponse.update({ - where: { id: modelResponse.id }, - data: { - output: response.value as Prisma.InputJsonObject, - statusCode: response.statusCode, - receivedAt: new Date(), - promptTokens: response.promptTokens, - completionTokens: response.completionTokens, - cost: response.cost, - }, - }); + await prisma.scenarioVariantCell.update({ + where: { id: cellId }, + data: { + retrievalStatus: "COMPLETE", + }, + }); + + await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider); + } else { + const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES; + const delay = calculateDelay(numPreviousTries); + const retryTime = new Date(Date.now() + delay); + + await prisma.modelResponse.update({ + where: { id: modelResponse.id }, + data: { + statusCode: response.statusCode, + errorMessage: response.message, + receivedAt: new Date(), + retryTime: shouldRetry ? retryTime : null, + }, + }); + + if (shouldRetry) { + await queryModel.enqueue( + { + cellId, + stream, + numPreviousTries: numPreviousTries + 1, + }, + retryTime, + ); await prisma.scenarioVariantCell.update({ where: { id: cellId }, data: { - retrievalStatus: "COMPLETE", + retrievalStatus: "PENDING", }, }); - - await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider); - break; } else { - const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; - const delay = calculateDelay(i); - - await prisma.modelResponse.update({ - where: { id: modelResponse.id }, + await prisma.scenarioVariantCell.update({ + where: { id: cellId }, data: { - statusCode: response.statusCode, - errorMessage: response.message, - receivedAt: new Date(), - retryTime: shouldRetry ? new Date(Date.now() + delay) : null, + retrievalStatus: "ERROR", }, }); - - if (shouldRetry) { - await sleep(delay); - } else { - await prisma.scenarioVariantCell.update({ - where: { id: cellId }, - data: { - retrievalStatus: "ERROR", - }, - }); - break; - } } } }); @@ -170,6 +180,6 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => { jobQueuedAt: new Date(), }, }), - queryModel.enqueue({ cellId, stream }), + queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }), ]); }; diff --git a/src/server/tasks/runNewEval.task.ts b/src/server/tasks/runNewEval.task.ts new file mode 100644 index 0000000..2da5ba4 --- /dev/null +++ b/src/server/tasks/runNewEval.task.ts @@ -0,0 +1,17 @@ +import { runAllEvals } from "../utils/evaluations"; +import defineTask from "./defineTask"; + +export type RunNewEvalJob = { + experimentId: string; +}; + +// When a new eval is created, we want to run it on all existing outputs, but return the new eval first +export const runNewEval = defineTask("runNewEval", async (task) => { + console.log("RUNNING TASK", task); + const { experimentId } = task; + await runAllEvals(experimentId); +}); + +export const queueRunNewEval = async (experimentId: string) => { + await runNewEval.enqueue({ experimentId }); +}; diff --git a/src/server/tasks/worker.ts b/src/server/tasks/worker.ts index 76c9e8b..74be92f 100644 --- a/src/server/tasks/worker.ts +++ b/src/server/tasks/worker.ts @@ -3,10 +3,11 @@ import "dotenv/config"; import { env } from "~/env.mjs"; import { queryModel } from "./queryModel.task"; +import { runNewEval } from "./runNewEval.task"; console.log("Starting worker"); -const registeredTasks = [queryModel]; +const registeredTasks = [queryModel, runNewEval]; const taskList = registeredTasks.reduce((acc, task) => { acc[task.task.identifier] = task.task.handler; @@ -16,7 +17,7 @@ const taskList = registeredTasks.reduce((acc, task) => { // Run a worker to execute jobs: const runner = await run({ connectionString: env.DATABASE_URL, - concurrency: 20, + concurrency: 50, // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc noHandleSignals: false, pollInterval: 1000, diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts index e4a7086..9259d91 100644 --- a/src/server/utils/evaluations.ts +++ b/src/server/utils/evaluations.ts @@ -46,6 +46,7 @@ export const runEvalsForOutput = async ( ); }; +// Will not run eval-output pairs that already exist in the database export const runAllEvals = async (experimentId: string) => { const outputs = await prisma.modelResponse.findMany({ where: {