diff --git a/app/src/components/OutputsTable/OutputCell/OutputCell.tsx b/app/src/components/OutputsTable/OutputCell/OutputCell.tsx index 50af4bd..f3da232 100644 --- a/app/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/app/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -43,7 +43,7 @@ export default function OutputCell({ type OutputSchema = Parameters[0]; - const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.forceRefetch.useMutation(); + const { mutateAsync: hardRefetchMutate } = api.scenarioVariantCells.hardRefetch.useMutation(); const [hardRefetch, hardRefetching] = useHandledAsyncCallback(async () => { await hardRefetchMutate({ scenarioId: scenario.id, variantId: variant.id }); await utils.scenarioVariantCells.get.invalidate({ diff --git a/app/src/server/api/routers/scenarioVariantCells.router.ts b/app/src/server/api/routers/scenarioVariantCells.router.ts index 8e4c83b..ab677d1 100644 --- a/app/src/server/api/routers/scenarioVariantCells.router.ts +++ b/app/src/server/api/routers/scenarioVariantCells.router.ts @@ -61,7 +61,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ evalsComplete, }; }), - forceRefetch: protectedProcedure + hardRefetch: protectedProcedure .input( z.object({ scenarioId: z.string(), @@ -85,7 +85,10 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }); if (!cell) { - await generateNewCell(input.variantId, input.scenarioId, { stream: true }); + await generateNewCell(input.variantId, input.scenarioId, { + stream: true, + hardRefetch: true, + }); return; } @@ -96,7 +99,7 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }, }); - await queueQueryModel(cell.id, true); + await queueQueryModel(cell.id, { stream: true, hardRefetch: true }); }), getTemplatedPromptMessage: publicProcedure .input( diff --git a/app/src/server/tasks/queryModel.task.ts b/app/src/server/tasks/queryModel.task.ts index 081bc2c..750ac51 100644 --- a/app/src/server/tasks/queryModel.task.ts +++ b/app/src/server/tasks/queryModel.task.ts @@ -25,7 +25,6 @@ function calculateDelay(numPreviousTries: number): number { } export const queryModel = defineTask("queryModel", async (task) => { - console.log("RUNNING TASK", task); const { cellId, stream, numPreviousTries } = task; const cell = await prisma.scenarioVariantCell.findUnique({ where: { id: cellId }, @@ -153,7 +152,7 @@ export const queryModel = defineTask("queryModel", async (task) = stream, numPreviousTries: numPreviousTries + 1, }, - { runAt: retryTime, jobKey: cellId }, + { runAt: retryTime, jobKey: cellId, priority: 3 }, ); await prisma.scenarioVariantCell.update({ where: { id: cellId }, @@ -172,7 +171,13 @@ export const queryModel = defineTask("queryModel", async (task) = } }); -export const queueQueryModel = async (cellId: string, stream: boolean) => { +export const queueQueryModel = async ( + cellId: string, + options: { stream?: boolean; hardRefetch?: boolean } = {}, +) => { + // Hard refetches are higher priority than streamed queries, which are higher priority than non-streamed queries. + const jobPriority = options.hardRefetch ? 0 : options.stream ? 1 : 2; + await Promise.all([ prisma.scenarioVariantCell.update({ where: { @@ -184,6 +189,13 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => { jobQueuedAt: new Date(), }, }), - queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }, { jobKey: cellId }), + + queryModel.enqueue( + { cellId, stream: options.stream ?? false, numPreviousTries: 0 }, + + // Streamed queries are higher priority than non-streamed queries. Lower + // numbers are higher priority in graphile-worker. + { jobKey: cellId, priority: jobPriority }, + ), ]); }; diff --git a/app/src/server/tasks/runNewEval.task.ts b/app/src/server/tasks/runNewEval.task.ts index 2da5ba4..83ae2e3 100644 --- a/app/src/server/tasks/runNewEval.task.ts +++ b/app/src/server/tasks/runNewEval.task.ts @@ -13,5 +13,6 @@ export const runNewEval = defineTask("runNewEval", async (task) = }); export const queueRunNewEval = async (experimentId: string) => { - await runNewEval.enqueue({ experimentId }); + // Evals are lower priority than completions + await runNewEval.enqueue({ experimentId }, { priority: 4 }); }; diff --git a/app/src/server/utils/generateNewCell.ts b/app/src/server/utils/generateNewCell.ts index 678740d..b1221f9 100644 --- a/app/src/server/utils/generateNewCell.ts +++ b/app/src/server/utils/generateNewCell.ts @@ -9,10 +9,8 @@ import parsePromptConstructor from "~/promptConstructor/parse"; export const generateNewCell = async ( variantId: string, scenarioId: string, - options?: { stream?: boolean }, + options: { stream?: boolean; hardRefetch?: boolean } = {}, ): Promise => { - const stream = options?.stream ?? false; - const variant = await prisma.promptVariant.findUnique({ where: { id: variantId, @@ -121,6 +119,6 @@ export const generateNewCell = async ( }), ); } else { - await queueQueryModel(cell.id, stream); + await queueQueryModel(cell.id, options); } };