Better streaming

- Always stream the visible scenarios, if the modelProvider supports it
 - Never stream the invisible scenarios

Also actually runs our query tasks in a background worker, which we weren't quite doing before.
This commit is contained in:
Kyle Corbitt
2023-07-24 18:34:30 -07:00
parent d6b97b29f7
commit e1cbeccb90
25 changed files with 152 additions and 153 deletions

View File

@@ -12,7 +12,7 @@
"dev:next": "next dev", "dev:next": "next dev",
"dev:wss": "pnpm tsx --watch src/wss-server.ts", "dev:wss": "pnpm tsx --watch src/wss-server.ts",
"dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts", "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
"dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'", "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'",
"postinstall": "prisma generate", "postinstall": "prisma generate",
"lint": "next lint", "lint": "next lint",
"start": "next start", "start": "next start",

View File

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel";

View File

@@ -90,11 +90,10 @@ enum CellRetrievalStatus {
model ScenarioVariantCell { model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
statusCode Int? statusCode Int?
errorMessage String? errorMessage String?
retryTime DateTime? retryTime DateTime?
streamingChannel String? retrievalStatus CellRetrievalStatus @default(COMPLETE)
retrievalStatus CellRetrievalStatus @default(COMPLETE)
modelOutput ModelOutput? modelOutput ModelOutput?

View File

@@ -164,5 +164,5 @@ await Promise.all(
testScenarioId: scenario.id, testScenarioId: scenario.id,
})), })),
) )
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)), .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })),
); );

View File

@@ -6,4 +6,7 @@ echo "Migrating the database"
pnpm prisma migrate deploy pnpm prisma migrate deploy
echo "Starting the server" echo "Starting the server"
pnpm start
pnpm concurrently --kill-others \
"pnpm start" \
"pnpm tsx src/server/tasks/worker.ts"

View File

@@ -19,7 +19,7 @@ import { useState } from "react";
import { RiExchangeFundsFill } from "react-icons/ri"; import { RiExchangeFundsFill } from "react-icons/ri";
import { type ProviderModel } from "~/modelProviders/types"; import { type ProviderModel } from "~/modelProviders/types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { lookupModel, modelLabel } from "~/utils/utils"; import { lookupModel, modelLabel } from "~/utils/utils";
import CompareFunctions from "../RefinePromptModal/CompareFunctions"; import CompareFunctions from "../RefinePromptModal/CompareFunctions";
import { ModelSearch } from "./ModelSearch"; import { ModelSearch } from "./ModelSearch";
@@ -38,6 +38,7 @@ export const ChangeModelModal = ({
model: variant.model, model: variant.model,
} as ProviderModel); } as ProviderModel);
const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>(); const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
const visibleScenarios = useVisibleScenarioIds();
const utils = api.useContext(); const utils = api.useContext();
@@ -68,6 +69,7 @@ export const ChangeModelModal = ({
await replaceVariantMutation.mutateAsync({ await replaceVariantMutation.mutateAsync({
id: variant.id, id: variant.id,
constructFn: modifiedPromptFn, constructFn: modifiedPromptFn,
streamScenarios: visibleScenarios,
}); });
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
onClose(); onClose();

View File

@@ -2,7 +2,12 @@ import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs"; import { BsPlus } from "react-icons/bs";
import { Text } from "@chakra-ui/react"; import { Text } from "@chakra-ui/react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; import {
useExperiment,
useExperimentAccess,
useHandledAsyncCallback,
useVisibleScenarioIds,
} from "~/utils/hooks";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import { ActionButton } from "./ScenariosHeader"; import { ActionButton } from "./ScenariosHeader";
@@ -10,11 +15,13 @@ export default function AddVariantButton() {
const experiment = useExperiment(); const experiment = useExperiment();
const mutation = api.promptVariants.create.useMutation(); const mutation = api.promptVariants.create.useMutation();
const utils = api.useContext(); const utils = api.useContext();
const visibleScenarios = useVisibleScenarioIds();
const [onClick, loading] = useHandledAsyncCallback(async () => { const [onClick, loading] = useHandledAsyncCallback(async () => {
if (!experiment.data) return; if (!experiment.data) return;
await mutation.mutateAsync({ await mutation.mutateAsync({
experimentId: experiment.data.id, experimentId: experiment.data.id,
streamScenarios: visibleScenarios,
}); });
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
}, [mutation]); }, [mutation]);

View File

@@ -67,8 +67,8 @@ export default function OutputCell({
const modelOutput = cell?.modelOutput; const modelOutput = cell?.modelOutput;
// Disconnect from socket if we're not streaming anymore // TODO: disconnect from socket if we're not streaming anymore
const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel); const streamedMessage = useSocket<OutputSchema>(cell?.id);
if (!vars) return null; if (!vars) return null;

View File

@@ -54,13 +54,13 @@ export const ScenariosHeader = () => {
</Text> </Text>
{canModify && ( {canModify && (
<Menu> <Menu>
<MenuButton mt={1}> <MenuButton
<IconButton as={IconButton}
variant="ghost" mt={1}
aria-label="Edit Scenarios" variant="ghost"
icon={<Icon as={loading ? Spinner : BsGear} />} aria-label="Edit Scenarios"
/> icon={<Icon as={loading ? Spinner : BsGear} />}
</MenuButton> />
<MenuList fontSize="md" zIndex="dropdown" mt={-3}> <MenuList fontSize="md" zIndex="dropdown" mt={-3}>
<MenuItem <MenuItem
icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />} icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}

View File

@@ -2,19 +2,24 @@ import {
Box, Box,
Button, Button,
HStack, HStack,
IconButton,
Spinner, Spinner,
Text,
Tooltip, Tooltip,
useToast, useToast,
Text,
IconButton,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import { FiMaximize, FiMinimize } from "react-icons/fi"; import { FiMaximize, FiMinimize } from "react-icons/fi";
import { editorBackground } from "~/state/sharedVariantEditor.slice"; import { editorBackground } from "~/state/sharedVariantEditor.slice";
import { useAppStore } from "~/state/store";
import { api } from "~/utils/api";
import {
useExperimentAccess,
useHandledAsyncCallback,
useModifierKeyLabel,
useVisibleScenarioIds,
} from "~/utils/hooks";
import { type PromptVariant } from "./types";
export default function VariantEditor(props: { variant: PromptVariant }) { export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess(); const { canModify } = useExperimentAccess();
@@ -63,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const replaceVariant = api.promptVariants.replaceVariant.useMutation(); const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext(); const utils = api.useContext();
const toast = useToast(); const toast = useToast();
const visibleScenarios = useVisibleScenarioIds();
const [onSave, saveInProgress] = useHandledAsyncCallback(async () => { const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
if (!editorRef.current) return; if (!editorRef.current) return;
@@ -91,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const resp = await replaceVariant.mutateAsync({ const resp = await replaceVariant.mutateAsync({
id: props.variant.id, id: props.variant.id,
constructFn: currentFn, constructFn: currentFn,
streamScenarios: visibleScenarios,
}); });
if (resp.status === "error") { if (resp.status === "error") {
return toast({ return toast({

View File

@@ -16,7 +16,7 @@ import {
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { BsStars } from "react-icons/bs"; import { BsStars } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client"; import { type PromptVariant } from "@prisma/client";
import { useState } from "react"; import { useState } from "react";
import CompareFunctions from "./CompareFunctions"; import CompareFunctions from "./CompareFunctions";
@@ -34,6 +34,7 @@ export const RefinePromptModal = ({
onClose: () => void; onClose: () => void;
}) => { }) => {
const utils = api.useContext(); const utils = api.useContext();
const visibleScenarios = useVisibleScenarioIds();
const refinementActions = const refinementActions =
frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {}; frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
@@ -73,6 +74,7 @@ export const RefinePromptModal = ({
await replaceVariantMutation.mutateAsync({ await replaceVariantMutation.mutateAsync({
id: variant.id, id: variant.id,
constructFn: refinedPromptFn, constructFn: refinedPromptFn,
streamScenarios: visibleScenarios,
}); });
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
onClose(); onClose();

View File

@@ -1,8 +1,7 @@
import { type PromptVariant } from "../OutputsTable/types"; import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { import {
Button,
Icon, Icon,
Menu, Menu,
MenuButton, MenuButton,
@@ -11,6 +10,7 @@ import {
MenuDivider, MenuDivider,
Text, Text,
Spinner, Spinner,
IconButton,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs"; import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa"; import { FaRegClone } from "react-icons/fa";
@@ -33,11 +33,13 @@ export default function VariantHeaderMenuButton({
const utils = api.useContext(); const utils = api.useContext();
const duplicateMutation = api.promptVariants.create.useMutation(); const duplicateMutation = api.promptVariants.create.useMutation();
const visibleScenarios = useVisibleScenarioIds();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => { const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({ await duplicateMutation.mutateAsync({
experimentId: variant.experimentId, experimentId: variant.experimentId,
variantId: variant.id, variantId: variant.id,
streamScenarios: visibleScenarios,
}); });
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
}, [duplicateMutation, variant.experimentId, variant.id]); }, [duplicateMutation, variant.experimentId, variant.id]);
@@ -56,15 +58,12 @@ export default function VariantHeaderMenuButton({
return ( return (
<> <>
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}> <Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
{duplicationInProgress ? ( <MenuButton
<Spinner boxSize={4} mx={3} my={3} /> as={IconButton}
) : ( variant="ghost"
<MenuButton> aria-label="Edit Scenarios"
<Button variant="ghost"> icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
<Icon as={BsGear} /> />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md"> <MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}> <MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>

View File

@@ -37,7 +37,7 @@ const modelProvider: OpenaiChatModelProvider = {
return null; return null;
}, },
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
shouldStream: (input) => input.stream ?? false, canStream: true,
getCompletion, getCompletion,
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -19,7 +19,7 @@ export async function getCompletion(
): Promise<CompletionResponse<ReplicateLlama2Output>> { ): Promise<CompletionResponse<ReplicateLlama2Output>> {
const start = Date.now(); const start = Date.now();
const { model, stream, ...rest } = input; const { model, ...rest } = input;
try { try {
const prediction = await replicate.predictions.create({ const prediction = await replicate.predictions.create({

View File

@@ -9,7 +9,6 @@ export type SupportedModel = (typeof supportedModels)[number];
export type ReplicateLlama2Input = { export type ReplicateLlama2Input = {
model: SupportedModel; model: SupportedModel;
prompt: string; prompt: string;
stream?: boolean;
max_length?: number; max_length?: number;
temperature?: number; temperature?: number;
top_p?: number; top_p?: number;
@@ -47,10 +46,6 @@ const modelProvider: ReplicateLlama2Provider = {
type: "string", type: "string",
description: "Prompt to send to Llama v2.", description: "Prompt to send to Llama v2.",
}, },
stream: {
type: "boolean",
description: "Whether to stream output from Llama v2.",
},
max_new_tokens: { max_new_tokens: {
type: "number", type: "number",
description: description:
@@ -78,7 +73,7 @@ const modelProvider: ReplicateLlama2Provider = {
}, },
required: ["model", "prompt"], required: ["model", "prompt"],
}, },
shouldStream: (input) => input.stream ?? false, canStream: true,
getCompletion, getCompletion,
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -48,7 +48,7 @@ export type CompletionResponse<T> =
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = { export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
getModel: (input: InputSchema) => SupportedModels | null; getModel: (input: InputSchema) => SupportedModels | null;
shouldStream: (input: InputSchema) => boolean; canStream: boolean;
inputSchema: JSONSchema4; inputSchema: JSONSchema4;
getCompletion: ( getCompletion: (
input: InputSchema, input: InputSchema,

View File

@@ -145,6 +145,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
variantId: z.string().optional(), variantId: z.string().optional(),
streamScenarios: z.array(z.string()),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -218,7 +219,9 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
for (const scenario of scenarios) { for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id); await generateNewCell(newVariant.id, scenario.id, {
stream: input.streamScenarios.includes(scenario.id),
});
} }
return newVariant; return newVariant;
@@ -325,6 +328,7 @@ export const promptVariantsRouter = createTRPCRouter({
z.object({ z.object({
id: z.string(), id: z.string(),
constructFn: z.string(), constructFn: z.string(),
streamScenarios: z.array(z.string()),
}), }),
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
@@ -382,7 +386,9 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
for (const scenario of scenarios) { for (const scenario of scenarios) {
await generateNewCell(newVariant.id, scenario.id); await generateNewCell(newVariant.id, scenario.id, {
stream: input.streamScenarios.includes(scenario.id),
});
} }
return { status: "ok" } as const; return { status: "ok" } as const;

View File

@@ -1,8 +1,8 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { queueQueryModel } from "~/server/tasks/queryModel.task";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVariantCellsRouter = createTRPCRouter({ export const scenarioVariantCellsRouter = createTRPCRouter({
@@ -62,14 +62,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
testScenarioId: input.scenarioId, testScenarioId: input.scenarioId,
}, },
}, },
include: { include: { modelOutput: true },
modelOutput: true,
},
}); });
if (!cell) { if (!cell) {
await generateNewCell(input.variantId, input.scenarioId); await generateNewCell(input.variantId, input.scenarioId, { stream: true });
return true; return;
} }
if (cell.modelOutput) { if (cell.modelOutput) {
@@ -79,12 +77,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}); });
} }
await prisma.scenarioVariantCell.update({ await queueQueryModel(cell.id, true);
where: { id: cell.id },
data: { retrievalStatus: "PENDING" },
});
await queueLLMRetrievalTask(cell.id);
return true;
}), }),
}); });

View File

@@ -86,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({
}); });
for (const variant of promptVariants) { for (const variant of promptVariants) {
await generateNewCell(variant.id, scenario.id); await generateNewCell(variant.id, scenario.id, { stream: true });
} }
}), }),
@@ -230,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({
}); });
for (const variant of promptVariants) { for (const variant of promptVariants) {
await generateNewCell(variant.id, newScenario.id); await generateNewCell(variant.id, newScenario.id, { stream: true });
} }
return newScenario; return newScenario;

View File

@@ -1,17 +1,17 @@
import { prisma } from "~/server/db";
import defineTask from "./defineTask";
import { sleep } from "../utils/sleep";
import { generateChannel } from "~/utils/generateChannel";
import { runEvalsForOutput } from "../utils/evaluations";
import { type Prisma } from "@prisma/client"; import { type Prisma } from "@prisma/client";
import parseConstructFn from "../utils/parseConstructFn";
import hashPrompt from "../utils/hashPrompt";
import { type JsonObject } from "type-fest"; import { type JsonObject } from "type-fest";
import modelProviders from "~/modelProviders/modelProviders"; import modelProviders from "~/modelProviders/modelProviders";
import { prisma } from "~/server/db";
import { wsConnection } from "~/utils/wsConnection"; import { wsConnection } from "~/utils/wsConnection";
import { runEvalsForOutput } from "../utils/evaluations";
import hashPrompt from "../utils/hashPrompt";
import parseConstructFn from "../utils/parseConstructFn";
import { sleep } from "../utils/sleep";
import defineTask from "./defineTask";
export type queryLLMJob = { export type QueryModelJob = {
scenarioVariantCellId: string; cellId: string;
stream: boolean;
}; };
const MAX_AUTO_RETRIES = 10; const MAX_AUTO_RETRIES = 10;
@@ -24,15 +24,16 @@ function calculateDelay(numPreviousTries: number): number {
return baseDelay + jitter; return baseDelay + jitter;
} }
export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => { export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
const { scenarioVariantCellId } = task; console.log("RUNNING TASK", task);
const { cellId, stream } = task;
const cell = await prisma.scenarioVariantCell.findUnique({ const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: scenarioVariantCellId }, where: { id: cellId },
include: { modelOutput: true }, include: { modelOutput: true },
}); });
if (!cell) { if (!cell) {
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
statusCode: 404, statusCode: 404,
errorMessage: "Cell not found", errorMessage: "Cell not found",
@@ -47,7 +48,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
return; return;
} }
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
retrievalStatus: "IN_PROGRESS", retrievalStatus: "IN_PROGRESS",
}, },
@@ -58,7 +59,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
}); });
if (!variant) { if (!variant) {
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
statusCode: 404, statusCode: 404,
errorMessage: "Prompt Variant not found", errorMessage: "Prompt Variant not found",
@@ -73,7 +74,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
}); });
if (!scenario) { if (!scenario) {
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
statusCode: 404, statusCode: 404,
errorMessage: "Scenario not found", errorMessage: "Scenario not found",
@@ -87,7 +88,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
if ("error" in prompt) { if ("error" in prompt) {
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
statusCode: 400, statusCode: 400,
errorMessage: prompt.error, errorMessage: prompt.error,
@@ -99,18 +100,9 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const provider = modelProviders[prompt.modelProvider]; const provider = modelProviders[prompt.modelProvider];
const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null; const onStream = stream
if (streamingChannel) {
// Save streaming channel so that UI can connect to it
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: { streamingChannel },
});
}
const onStream = streamingChannel
? (partialOutput: (typeof provider)["_outputSchema"]) => { ? (partialOutput: (typeof provider)["_outputSchema"]) => {
wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput }); wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
} }
: null; : null;
@@ -121,7 +113,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const modelOutput = await prisma.modelOutput.create({ const modelOutput = await prisma.modelOutput.create({
data: { data: {
scenarioVariantCellId, scenarioVariantCellId: cellId,
inputHash, inputHash,
output: response.value as Prisma.InputJsonObject, output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete, timeToComplete: response.timeToComplete,
@@ -132,7 +124,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
}); });
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
statusCode: response.statusCode, statusCode: response.statusCode,
retrievalStatus: "COMPLETE", retrievalStatus: "COMPLETE",
@@ -146,7 +138,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
const delay = calculateDelay(i); const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({ await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId }, where: { id: cellId },
data: { data: {
errorMessage: response.message, errorMessage: response.message,
statusCode: response.statusCode, statusCode: response.statusCode,
@@ -163,3 +155,21 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
} }
} }
}); });
export const queueQueryModel = async (cellId: string, stream: boolean) => {
console.log("queueQueryModel", cellId, stream);
await Promise.all([
prisma.scenarioVariantCell.update({
where: {
id: cellId,
},
data: {
retrievalStatus: "PENDING",
errorMessage: null,
},
}),
await queryModel.enqueue({ cellId, stream }),
console.log("queued"),
]);
};

View File

@@ -2,39 +2,27 @@ import { type TaskList, run } from "graphile-worker";
import "dotenv/config"; import "dotenv/config";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import { queryLLM } from "./queryLLM.task"; import { queryModel } from "./queryModel.task";
const registeredTasks = [queryLLM]; console.log("Starting worker");
const registeredTasks = [queryModel];
const taskList = registeredTasks.reduce((acc, task) => { const taskList = registeredTasks.reduce((acc, task) => {
acc[task.task.identifier] = task.task.handler; acc[task.task.identifier] = task.task.handler;
return acc; return acc;
}, {} as TaskList); }, {} as TaskList);
async function main() { // Run a worker to execute jobs:
// Run a worker to execute jobs: const runner = await run({
const runner = await run({ connectionString: env.DATABASE_URL,
connectionString: env.DATABASE_URL, concurrency: 20,
concurrency: 20, // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc noHandleSignals: false,
noHandleSignals: false, pollInterval: 1000,
pollInterval: 1000, taskList,
// you can set the taskList or taskDirectory but not both
taskList,
// or:
// taskDirectory: `${__dirname}/tasks`,
});
// Immediately await (or otherwise handled) the resulting promise, to avoid
// "unhandled rejection" errors causing a process crash in the event of
// something going wrong.
await runner.promise;
// If the worker exits (whether through fatal error or otherwise), the above
// promise will resolve/reject.
}
main().catch((err) => {
console.error("Unhandled error occurred running worker: ", err);
process.exit(1);
}); });
console.log("Worker successfully started");
await runner.promise;

View File

@@ -1,12 +1,18 @@
import { type Prisma } from "@prisma/client"; import { type Prisma } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import parseConstructFn from "./parseConstructFn"; import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest"; import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt"; import hashPrompt from "./hashPrompt";
import { omit } from "lodash-es"; import { omit } from "lodash-es";
import { queueQueryModel } from "../tasks/queryModel.task";
export const generateNewCell = async (
variantId: string,
scenarioId: string,
options?: { stream?: boolean },
): Promise<void> => {
const stream = options?.stream ?? false;
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
const variant = await prisma.promptVariant.findUnique({ const variant = await prisma.promptVariant.findUnique({
where: { where: {
id: variantId, id: variantId,
@@ -98,6 +104,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
}), }),
); );
} else { } else {
cell = await queueLLMRetrievalTask(cell.id); await queueQueryModel(cell.id, stream);
} }
}; };

View File

@@ -1,22 +0,0 @@
import { prisma } from "../db";
import { queryLLM } from "../tasks/queryLLM.task";
export const queueLLMRetrievalTask = async (cellId: string) => {
const updatedCell = await prisma.scenarioVariantCell.update({
where: {
id: cellId,
},
data: {
retrievalStatus: "PENDING",
errorMessage: null,
},
include: {
modelOutput: true,
},
});
// @ts-expect-error we aren't passing the helpers but that's ok
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
return updatedCell;
};

View File

@@ -1,5 +0,0 @@
// generate random channel id
export const generateChannel = () => {
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
};

View File

@@ -106,3 +106,5 @@ export const useScenarios = () => {
{ enabled: experiment.data?.id != null }, { enabled: experiment.data?.id != null },
); );
}; };
export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];