diff --git a/package.json b/package.json index 68c45b2..5d8806e 100644 --- a/package.json +++ b/package.json @@ -12,7 +12,7 @@ "dev:next": "next dev", "dev:wss": "pnpm tsx --watch src/wss-server.ts", "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts", - "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'", + "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'", "postinstall": "prisma generate", "lint": "next lint", "start": "next start", diff --git a/prisma/migrations/20230725005817_use_id_as_streaming_channel/migration.sql b/prisma/migrations/20230725005817_use_id_as_streaming_channel/migration.sql new file mode 100644 index 0000000..e87e298 --- /dev/null +++ b/prisma/migrations/20230725005817_use_id_as_streaming_channel/migration.sql @@ -0,0 +1,8 @@ +/* + Warnings: + + - You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost. + +*/ +-- AlterTable +ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel"; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 2cc2d86..ca11195 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -90,11 +90,10 @@ enum CellRetrievalStatus { model ScenarioVariantCell { id String @id @default(uuid()) @db.Uuid - statusCode Int? - errorMessage String? - retryTime DateTime? - streamingChannel String? - retrievalStatus CellRetrievalStatus @default(COMPLETE) + statusCode Int? + errorMessage String? + retryTime DateTime? + retrievalStatus CellRetrievalStatus @default(COMPLETE) modelOutput ModelOutput? diff --git a/prisma/seed.ts b/prisma/seed.ts index abdc071..adc39e3 100644 --- a/prisma/seed.ts +++ b/prisma/seed.ts @@ -164,5 +164,5 @@ await Promise.all( testScenarioId: scenario.id, })), ) - .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)), + .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })), ); diff --git a/run-prod.sh b/run-prod.sh index 1feb207..30ec09c 100755 --- a/run-prod.sh +++ b/run-prod.sh @@ -6,4 +6,7 @@ echo "Migrating the database" pnpm prisma migrate deploy echo "Starting the server" -pnpm start \ No newline at end of file + +pnpm concurrently --kill-others \ + "pnpm start" \ + "pnpm tsx src/server/tasks/worker.ts" \ No newline at end of file diff --git a/src/components/ChangeModelModal/ChangeModelModal.tsx b/src/components/ChangeModelModal/ChangeModelModal.tsx index 15ff507..c59391c 100644 --- a/src/components/ChangeModelModal/ChangeModelModal.tsx +++ b/src/components/ChangeModelModal/ChangeModelModal.tsx @@ -19,7 +19,7 @@ import { useState } from "react"; import { RiExchangeFundsFill } from "react-icons/ri"; import { type ProviderModel } from "~/modelProviders/types"; import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { lookupModel, modelLabel } from "~/utils/utils"; import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import { ModelSearch } from "./ModelSearch"; @@ -38,6 +38,7 @@ export const ChangeModelModal = ({ model: variant.model, } as ProviderModel); const [convertedModel, setConvertedModel] = useState(); + const visibleScenarios = useVisibleScenarioIds(); const utils = api.useContext(); @@ -68,6 +69,7 @@ export const ChangeModelModal = ({ await replaceVariantMutation.mutateAsync({ id: variant.id, constructFn: modifiedPromptFn, + streamScenarios: visibleScenarios, }); await utils.promptVariants.list.invalidate(); onClose(); diff --git a/src/components/OutputsTable/AddVariantButton.tsx b/src/components/OutputsTable/AddVariantButton.tsx index 7117d33..6e76d9d 100644 --- a/src/components/OutputsTable/AddVariantButton.tsx +++ b/src/components/OutputsTable/AddVariantButton.tsx @@ -2,7 +2,12 @@ import { Box, Flex, Icon, Spinner } from "@chakra-ui/react"; import { BsPlus } from "react-icons/bs"; import { Text } from "@chakra-ui/react"; import { api } from "~/utils/api"; -import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; +import { + useExperiment, + useExperimentAccess, + useHandledAsyncCallback, + useVisibleScenarioIds, +} from "~/utils/hooks"; import { cellPadding } from "../constants"; import { ActionButton } from "./ScenariosHeader"; @@ -10,11 +15,13 @@ export default function AddVariantButton() { const experiment = useExperiment(); const mutation = api.promptVariants.create.useMutation(); const utils = api.useContext(); + const visibleScenarios = useVisibleScenarioIds(); const [onClick, loading] = useHandledAsyncCallback(async () => { if (!experiment.data) return; await mutation.mutateAsync({ experimentId: experiment.data.id, + streamScenarios: visibleScenarios, }); await utils.promptVariants.list.invalidate(); }, [mutation]); diff --git a/src/components/OutputsTable/OutputCell/OutputCell.tsx b/src/components/OutputsTable/OutputCell/OutputCell.tsx index 5cf288b..e8450ef 100644 --- a/src/components/OutputsTable/OutputCell/OutputCell.tsx +++ b/src/components/OutputsTable/OutputCell/OutputCell.tsx @@ -67,8 +67,8 @@ export default function OutputCell({ const modelOutput = cell?.modelOutput; - // Disconnect from socket if we're not streaming anymore - const streamedMessage = useSocket(cell?.streamingChannel); + // TODO: disconnect from socket if we're not streaming anymore + const streamedMessage = useSocket(cell?.id); if (!vars) return null; diff --git a/src/components/OutputsTable/ScenariosHeader.tsx b/src/components/OutputsTable/ScenariosHeader.tsx index 4ec68bf..6dacf74 100644 --- a/src/components/OutputsTable/ScenariosHeader.tsx +++ b/src/components/OutputsTable/ScenariosHeader.tsx @@ -54,13 +54,13 @@ export const ScenariosHeader = () => { {canModify && ( - - } - /> - + } + /> } diff --git a/src/components/OutputsTable/VariantEditor.tsx b/src/components/OutputsTable/VariantEditor.tsx index 7e80c7f..aad1644 100644 --- a/src/components/OutputsTable/VariantEditor.tsx +++ b/src/components/OutputsTable/VariantEditor.tsx @@ -2,19 +2,24 @@ import { Box, Button, HStack, + IconButton, Spinner, + Text, Tooltip, useToast, - Text, - IconButton, } from "@chakra-ui/react"; -import { useRef, useEffect, useState, useCallback } from "react"; -import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; -import { type PromptVariant } from "./types"; -import { api } from "~/utils/api"; -import { useAppStore } from "~/state/store"; +import { useCallback, useEffect, useRef, useState } from "react"; import { FiMaximize, FiMinimize } from "react-icons/fi"; import { editorBackground } from "~/state/sharedVariantEditor.slice"; +import { useAppStore } from "~/state/store"; +import { api } from "~/utils/api"; +import { + useExperimentAccess, + useHandledAsyncCallback, + useModifierKeyLabel, + useVisibleScenarioIds, +} from "~/utils/hooks"; +import { type PromptVariant } from "./types"; export default function VariantEditor(props: { variant: PromptVariant }) { const { canModify } = useExperimentAccess(); @@ -63,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) { const replaceVariant = api.promptVariants.replaceVariant.useMutation(); const utils = api.useContext(); const toast = useToast(); + const visibleScenarios = useVisibleScenarioIds(); const [onSave, saveInProgress] = useHandledAsyncCallback(async () => { if (!editorRef.current) return; @@ -91,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) { const resp = await replaceVariant.mutateAsync({ id: props.variant.id, constructFn: currentFn, + streamScenarios: visibleScenarios, }); if (resp.status === "error") { return toast({ diff --git a/src/components/RefinePromptModal/RefinePromptModal.tsx b/src/components/RefinePromptModal/RefinePromptModal.tsx index 4104dd3..49bdf0e 100644 --- a/src/components/RefinePromptModal/RefinePromptModal.tsx +++ b/src/components/RefinePromptModal/RefinePromptModal.tsx @@ -16,7 +16,7 @@ import { } from "@chakra-ui/react"; import { BsStars } from "react-icons/bs"; import { api } from "~/utils/api"; -import { useHandledAsyncCallback } from "~/utils/hooks"; +import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { type PromptVariant } from "@prisma/client"; import { useState } from "react"; import CompareFunctions from "./CompareFunctions"; @@ -34,6 +34,7 @@ export const RefinePromptModal = ({ onClose: () => void; }) => { const utils = api.useContext(); + const visibleScenarios = useVisibleScenarioIds(); const refinementActions = frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {}; @@ -73,6 +74,7 @@ export const RefinePromptModal = ({ await replaceVariantMutation.mutateAsync({ id: variant.id, constructFn: refinedPromptFn, + streamScenarios: visibleScenarios, }); await utils.promptVariants.list.invalidate(); onClose(); diff --git a/src/components/VariantHeader/VariantHeaderMenuButton.tsx b/src/components/VariantHeader/VariantHeaderMenuButton.tsx index 3ddfee4..add2757 100644 --- a/src/components/VariantHeader/VariantHeaderMenuButton.tsx +++ b/src/components/VariantHeader/VariantHeaderMenuButton.tsx @@ -1,8 +1,7 @@ import { type PromptVariant } from "../OutputsTable/types"; import { api } from "~/utils/api"; -import { useHandledAsyncCallback } from "~/utils/hooks"; +import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks"; import { - Button, Icon, Menu, MenuButton, @@ -11,6 +10,7 @@ import { MenuDivider, Text, Spinner, + IconButton, } from "@chakra-ui/react"; import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs"; import { FaRegClone } from "react-icons/fa"; @@ -33,11 +33,13 @@ export default function VariantHeaderMenuButton({ const utils = api.useContext(); const duplicateMutation = api.promptVariants.create.useMutation(); + const visibleScenarios = useVisibleScenarioIds(); const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => { await duplicateMutation.mutateAsync({ experimentId: variant.experimentId, variantId: variant.id, + streamScenarios: visibleScenarios, }); await utils.promptVariants.list.invalidate(); }, [duplicateMutation, variant.experimentId, variant.id]); @@ -56,15 +58,12 @@ export default function VariantHeaderMenuButton({ return ( <> setMenuOpen(true)} onClose={() => setMenuOpen(false)}> - {duplicationInProgress ? ( - - ) : ( - - - - )} + } + /> } onClick={duplicateVariant}> diff --git a/src/modelProviders/openai-ChatCompletion/index.ts b/src/modelProviders/openai-ChatCompletion/index.ts index 9aa882c..2b4e90c 100644 --- a/src/modelProviders/openai-ChatCompletion/index.ts +++ b/src/modelProviders/openai-ChatCompletion/index.ts @@ -37,7 +37,7 @@ const modelProvider: OpenaiChatModelProvider = { return null; }, inputSchema: inputSchema as JSONSchema4, - shouldStream: (input) => input.stream ?? false, + canStream: true, getCompletion, ...frontendModelProvider, }; diff --git a/src/modelProviders/replicate-llama2/getCompletion.ts b/src/modelProviders/replicate-llama2/getCompletion.ts index c95bf4c..ab26a13 100644 --- a/src/modelProviders/replicate-llama2/getCompletion.ts +++ b/src/modelProviders/replicate-llama2/getCompletion.ts @@ -19,7 +19,7 @@ export async function getCompletion( ): Promise> { const start = Date.now(); - const { model, stream, ...rest } = input; + const { model, ...rest } = input; try { const prediction = await replicate.predictions.create({ diff --git a/src/modelProviders/replicate-llama2/index.ts b/src/modelProviders/replicate-llama2/index.ts index 1907a65..786bf33 100644 --- a/src/modelProviders/replicate-llama2/index.ts +++ b/src/modelProviders/replicate-llama2/index.ts @@ -9,7 +9,6 @@ export type SupportedModel = (typeof supportedModels)[number]; export type ReplicateLlama2Input = { model: SupportedModel; prompt: string; - stream?: boolean; max_length?: number; temperature?: number; top_p?: number; @@ -47,10 +46,6 @@ const modelProvider: ReplicateLlama2Provider = { type: "string", description: "Prompt to send to Llama v2.", }, - stream: { - type: "boolean", - description: "Whether to stream output from Llama v2.", - }, max_new_tokens: { type: "number", description: @@ -78,7 +73,7 @@ const modelProvider: ReplicateLlama2Provider = { }, required: ["model", "prompt"], }, - shouldStream: (input) => input.stream ?? false, + canStream: true, getCompletion, ...frontendModelProvider, }; diff --git a/src/modelProviders/types.ts b/src/modelProviders/types.ts index 87111f2..a0b3154 100644 --- a/src/modelProviders/types.ts +++ b/src/modelProviders/types.ts @@ -48,7 +48,7 @@ export type CompletionResponse = export type ModelProvider = { getModel: (input: InputSchema) => SupportedModels | null; - shouldStream: (input: InputSchema) => boolean; + canStream: boolean; inputSchema: JSONSchema4; getCompletion: ( input: InputSchema, diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index b6a32ef..8832dc4 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -145,6 +145,7 @@ export const promptVariantsRouter = createTRPCRouter({ z.object({ experimentId: z.string(), variantId: z.string().optional(), + streamScenarios: z.array(z.string()), }), ) .mutation(async ({ input, ctx }) => { @@ -218,7 +219,9 @@ export const promptVariantsRouter = createTRPCRouter({ }); for (const scenario of scenarios) { - await generateNewCell(newVariant.id, scenario.id); + await generateNewCell(newVariant.id, scenario.id, { + stream: input.streamScenarios.includes(scenario.id), + }); } return newVariant; @@ -325,6 +328,7 @@ export const promptVariantsRouter = createTRPCRouter({ z.object({ id: z.string(), constructFn: z.string(), + streamScenarios: z.array(z.string()), }), ) .mutation(async ({ input, ctx }) => { @@ -382,7 +386,9 @@ export const promptVariantsRouter = createTRPCRouter({ }); for (const scenario of scenarios) { - await generateNewCell(newVariant.id, scenario.id); + await generateNewCell(newVariant.id, scenario.id, { + stream: input.streamScenarios.includes(scenario.id), + }); } return { status: "ok" } as const; diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts index ed8a34d..b0dbecb 100644 --- a/src/server/api/routers/scenarioVariantCells.router.ts +++ b/src/server/api/routers/scenarioVariantCells.router.ts @@ -1,8 +1,8 @@ import { z } from "zod"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; +import { queueQueryModel } from "~/server/tasks/queryModel.task"; import { generateNewCell } from "~/server/utils/generateNewCell"; -import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask"; import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const scenarioVariantCellsRouter = createTRPCRouter({ @@ -62,14 +62,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ testScenarioId: input.scenarioId, }, }, - include: { - modelOutput: true, - }, + include: { modelOutput: true }, }); if (!cell) { - await generateNewCell(input.variantId, input.scenarioId); - return true; + await generateNewCell(input.variantId, input.scenarioId, { stream: true }); + return; } if (cell.modelOutput) { @@ -79,12 +77,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }); } - await prisma.scenarioVariantCell.update({ - where: { id: cell.id }, - data: { retrievalStatus: "PENDING" }, - }); - - await queueLLMRetrievalTask(cell.id); - return true; + await queueQueryModel(cell.id, true); }), }); diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts index d3eafb4..79b65af 100644 --- a/src/server/api/routers/scenarios.router.ts +++ b/src/server/api/routers/scenarios.router.ts @@ -86,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({ }); for (const variant of promptVariants) { - await generateNewCell(variant.id, scenario.id); + await generateNewCell(variant.id, scenario.id, { stream: true }); } }), @@ -230,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({ }); for (const variant of promptVariants) { - await generateNewCell(variant.id, newScenario.id); + await generateNewCell(variant.id, newScenario.id, { stream: true }); } return newScenario; diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryModel.task.ts similarity index 78% rename from src/server/tasks/queryLLM.task.ts rename to src/server/tasks/queryModel.task.ts index 9afb050..d9c427f 100644 --- a/src/server/tasks/queryLLM.task.ts +++ b/src/server/tasks/queryModel.task.ts @@ -1,17 +1,17 @@ -import { prisma } from "~/server/db"; -import defineTask from "./defineTask"; -import { sleep } from "../utils/sleep"; -import { generateChannel } from "~/utils/generateChannel"; -import { runEvalsForOutput } from "../utils/evaluations"; import { type Prisma } from "@prisma/client"; -import parseConstructFn from "../utils/parseConstructFn"; -import hashPrompt from "../utils/hashPrompt"; import { type JsonObject } from "type-fest"; import modelProviders from "~/modelProviders/modelProviders"; +import { prisma } from "~/server/db"; import { wsConnection } from "~/utils/wsConnection"; +import { runEvalsForOutput } from "../utils/evaluations"; +import hashPrompt from "../utils/hashPrompt"; +import parseConstructFn from "../utils/parseConstructFn"; +import { sleep } from "../utils/sleep"; +import defineTask from "./defineTask"; -export type queryLLMJob = { - scenarioVariantCellId: string; +export type QueryModelJob = { + cellId: string; + stream: boolean; }; const MAX_AUTO_RETRIES = 10; @@ -24,15 +24,16 @@ function calculateDelay(numPreviousTries: number): number { return baseDelay + jitter; } -export const queryLLM = defineTask("queryLLM", async (task) => { - const { scenarioVariantCellId } = task; +export const queryModel = defineTask("queryModel", async (task) => { + console.log("RUNNING TASK", task); + const { cellId, stream } = task; const cell = await prisma.scenarioVariantCell.findUnique({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, include: { modelOutput: true }, }); if (!cell) { await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { statusCode: 404, errorMessage: "Cell not found", @@ -47,7 +48,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { return; } await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { retrievalStatus: "IN_PROGRESS", }, @@ -58,7 +59,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { }); if (!variant) { await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { statusCode: 404, errorMessage: "Prompt Variant not found", @@ -73,7 +74,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { }); if (!scenario) { await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { statusCode: 404, errorMessage: "Scenario not found", @@ -87,7 +88,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { if ("error" in prompt) { await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { statusCode: 400, errorMessage: prompt.error, @@ -99,18 +100,9 @@ export const queryLLM = defineTask("queryLLM", async (task) => { const provider = modelProviders[prompt.modelProvider]; - const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null; - - if (streamingChannel) { - // Save streaming channel so that UI can connect to it - await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, - data: { streamingChannel }, - }); - } - const onStream = streamingChannel + const onStream = stream ? (partialOutput: (typeof provider)["_outputSchema"]) => { - wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput }); + wsConnection.emit("message", { channel: cell.id, payload: partialOutput }); } : null; @@ -121,7 +113,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { const modelOutput = await prisma.modelOutput.create({ data: { - scenarioVariantCellId, + scenarioVariantCellId: cellId, inputHash, output: response.value as Prisma.InputJsonObject, timeToComplete: response.timeToComplete, @@ -132,7 +124,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { }); await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { statusCode: response.statusCode, retrievalStatus: "COMPLETE", @@ -146,7 +138,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { const delay = calculateDelay(i); await prisma.scenarioVariantCell.update({ - where: { id: scenarioVariantCellId }, + where: { id: cellId }, data: { errorMessage: response.message, statusCode: response.statusCode, @@ -163,3 +155,21 @@ export const queryLLM = defineTask("queryLLM", async (task) => { } } }); + +export const queueQueryModel = async (cellId: string, stream: boolean) => { + console.log("queueQueryModel", cellId, stream); + await Promise.all([ + prisma.scenarioVariantCell.update({ + where: { + id: cellId, + }, + data: { + retrievalStatus: "PENDING", + errorMessage: null, + }, + }), + + await queryModel.enqueue({ cellId, stream }), + console.log("queued"), + ]); +}; diff --git a/src/server/tasks/worker.ts b/src/server/tasks/worker.ts index e2fc916..76c9e8b 100644 --- a/src/server/tasks/worker.ts +++ b/src/server/tasks/worker.ts @@ -2,39 +2,27 @@ import { type TaskList, run } from "graphile-worker"; import "dotenv/config"; import { env } from "~/env.mjs"; -import { queryLLM } from "./queryLLM.task"; +import { queryModel } from "./queryModel.task"; -const registeredTasks = [queryLLM]; +console.log("Starting worker"); + +const registeredTasks = [queryModel]; const taskList = registeredTasks.reduce((acc, task) => { acc[task.task.identifier] = task.task.handler; return acc; }, {} as TaskList); -async function main() { - // Run a worker to execute jobs: - const runner = await run({ - connectionString: env.DATABASE_URL, - concurrency: 20, - // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc - noHandleSignals: false, - pollInterval: 1000, - // you can set the taskList or taskDirectory but not both - taskList, - // or: - // taskDirectory: `${__dirname}/tasks`, - }); - - // Immediately await (or otherwise handled) the resulting promise, to avoid - // "unhandled rejection" errors causing a process crash in the event of - // something going wrong. - await runner.promise; - - // If the worker exits (whether through fatal error or otherwise), the above - // promise will resolve/reject. -} - -main().catch((err) => { - console.error("Unhandled error occurred running worker: ", err); - process.exit(1); +// Run a worker to execute jobs: +const runner = await run({ + connectionString: env.DATABASE_URL, + concurrency: 20, + // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc + noHandleSignals: false, + pollInterval: 1000, + taskList, }); + +console.log("Worker successfully started"); + +await runner.promise; diff --git a/src/server/utils/generateNewCell.ts b/src/server/utils/generateNewCell.ts index ae5276d..0c4c2be 100644 --- a/src/server/utils/generateNewCell.ts +++ b/src/server/utils/generateNewCell.ts @@ -1,12 +1,18 @@ import { type Prisma } from "@prisma/client"; import { prisma } from "../db"; -import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask"; import parseConstructFn from "./parseConstructFn"; import { type JsonObject } from "type-fest"; import hashPrompt from "./hashPrompt"; import { omit } from "lodash-es"; +import { queueQueryModel } from "../tasks/queryModel.task"; + +export const generateNewCell = async ( + variantId: string, + scenarioId: string, + options?: { stream?: boolean }, +): Promise => { + const stream = options?.stream ?? false; -export const generateNewCell = async (variantId: string, scenarioId: string): Promise => { const variant = await prisma.promptVariant.findUnique({ where: { id: variantId, @@ -98,6 +104,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr }), ); } else { - cell = await queueLLMRetrievalTask(cell.id); + await queueQueryModel(cell.id, stream); } }; diff --git a/src/server/utils/queueLLMRetrievalTask.ts b/src/server/utils/queueLLMRetrievalTask.ts deleted file mode 100644 index 762b708..0000000 --- a/src/server/utils/queueLLMRetrievalTask.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { prisma } from "../db"; -import { queryLLM } from "../tasks/queryLLM.task"; - -export const queueLLMRetrievalTask = async (cellId: string) => { - const updatedCell = await prisma.scenarioVariantCell.update({ - where: { - id: cellId, - }, - data: { - retrievalStatus: "PENDING", - errorMessage: null, - }, - include: { - modelOutput: true, - }, - }); - - // @ts-expect-error we aren't passing the helpers but that's ok - void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console }); - - return updatedCell; -}; diff --git a/src/utils/generateChannel.ts b/src/utils/generateChannel.ts deleted file mode 100644 index c7e7c6b..0000000 --- a/src/utils/generateChannel.ts +++ /dev/null @@ -1,5 +0,0 @@ -// generate random channel id - -export const generateChannel = () => { - return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15); -}; diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts index 0cd5c39..3d27d55 100644 --- a/src/utils/hooks.ts +++ b/src/utils/hooks.ts @@ -106,3 +106,5 @@ export const useScenarios = () => { { enabled: experiment.data?.id != null }, ); }; + +export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];