Better streaming
- Always stream the visible scenarios, if the modelProvider supports it
- Never stream the invisible scenarios

Also actually runs our query tasks in a background worker, which we weren't quite doing before.
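The mechanics, condensed: the client collects the IDs of the scenarios currently visible (via the new useVisibleScenarioIds() hook) and sends them as streamScenarios; for each cell the server enqueues a query job whose stream flag is true only when that cell's scenario is in the list. A minimal sketch of that decision, with illustrative names (enqueueCellsForVariant, queueCellQuery) that are not exports from this commit:

// Sketch only: stand-in for the generateNewCell -> queueQueryModel chain in the diff below.
declare function queueCellQuery(
  variantId: string,
  scenarioId: string,
  stream: boolean,
): Promise<void>;

async function enqueueCellsForVariant(
  variantId: string,
  scenarioIds: string[],
  streamScenarios: string[], // visible scenario IDs collected client-side
): Promise<void> {
  for (const scenarioId of scenarioIds) {
    // Mirrors generateNewCell(variantId, scenarioId, { stream: streamScenarios.includes(scenario.id) })
    await queueCellQuery(variantId, scenarioId, streamScenarios.includes(scenarioId));
  }
}

On the worker side, the queryModel task only emits partial output over the websocket (keyed by the cell's ID) when the job was enqueued with stream: true.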
@@ -12,7 +12,7 @@
     "dev:next": "next dev",
     "dev:wss": "pnpm tsx --watch src/wss-server.ts",
     "dev:worker": "NODE_ENV='development' pnpm tsx --watch src/server/tasks/worker.ts",
-    "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss'",
+    "dev": "concurrently --kill-others 'pnpm dev:next' 'pnpm dev:wss' 'pnpm dev:worker'",
     "postinstall": "prisma generate",
     "lint": "next lint",
     "start": "next start",
@@ -0,0 +1,8 @@
+/*
+  Warnings:
+
+  - You are about to drop the column `streamingChannel` on the `ScenarioVariantCell` table. All the data in the column will be lost.
+
+*/
+-- AlterTable
+ALTER TABLE "ScenarioVariantCell" DROP COLUMN "streamingChannel";
@@ -90,11 +90,10 @@ enum CellRetrievalStatus {
 model ScenarioVariantCell {
   id String @id @default(uuid()) @db.Uuid

   statusCode Int?
   errorMessage String?
   retryTime DateTime?
-  streamingChannel String?
   retrievalStatus CellRetrievalStatus @default(COMPLETE)

   modelOutput ModelOutput?

@@ -164,5 +164,5 @@ await Promise.all(
         testScenarioId: scenario.id,
       })),
     )
-    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
+    .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId, { stream: false })),
 );
@@ -6,4 +6,7 @@ echo "Migrating the database"
 pnpm prisma migrate deploy

 echo "Starting the server"
-pnpm start
+
+pnpm concurrently --kill-others \
+  "pnpm start" \
+  "pnpm tsx src/server/tasks/worker.ts"
@@ -19,7 +19,7 @@ import { useState } from "react";
 import { RiExchangeFundsFill } from "react-icons/ri";
 import { type ProviderModel } from "~/modelProviders/types";
 import { api } from "~/utils/api";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { useExperiment, useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
 import { lookupModel, modelLabel } from "~/utils/utils";
 import CompareFunctions from "../RefinePromptModal/CompareFunctions";
 import { ModelSearch } from "./ModelSearch";
@@ -38,6 +38,7 @@ export const ChangeModelModal = ({
     model: variant.model,
   } as ProviderModel);
   const [convertedModel, setConvertedModel] = useState<ProviderModel | undefined>();
+  const visibleScenarios = useVisibleScenarioIds();

   const utils = api.useContext();

@@ -68,6 +69,7 @@ export const ChangeModelModal = ({
     await replaceVariantMutation.mutateAsync({
       id: variant.id,
       constructFn: modifiedPromptFn,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
     onClose();
@@ -2,7 +2,12 @@ import { Box, Flex, Icon, Spinner } from "@chakra-ui/react";
 import { BsPlus } from "react-icons/bs";
 import { Text } from "@chakra-ui/react";
 import { api } from "~/utils/api";
-import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
+import {
+  useExperiment,
+  useExperimentAccess,
+  useHandledAsyncCallback,
+  useVisibleScenarioIds,
+} from "~/utils/hooks";
 import { cellPadding } from "../constants";
 import { ActionButton } from "./ScenariosHeader";

@@ -10,11 +15,13 @@ export default function AddVariantButton() {
   const experiment = useExperiment();
   const mutation = api.promptVariants.create.useMutation();
   const utils = api.useContext();
+  const visibleScenarios = useVisibleScenarioIds();

   const [onClick, loading] = useHandledAsyncCallback(async () => {
     if (!experiment.data) return;
     await mutation.mutateAsync({
       experimentId: experiment.data.id,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
   }, [mutation]);
@@ -67,8 +67,8 @@ export default function OutputCell({

   const modelOutput = cell?.modelOutput;

-  // Disconnect from socket if we're not streaming anymore
-  const streamedMessage = useSocket<OutputSchema>(cell?.streamingChannel);
+  // TODO: disconnect from socket if we're not streaming anymore
+  const streamedMessage = useSocket<OutputSchema>(cell?.id);

   if (!vars) return null;

@@ -54,13 +54,13 @@ export const ScenariosHeader = () => {
       </Text>
       {canModify && (
         <Menu>
-          <MenuButton mt={1}>
-            <IconButton
-              variant="ghost"
-              aria-label="Edit Scenarios"
-              icon={<Icon as={loading ? Spinner : BsGear} />}
-            />
-          </MenuButton>
+          <MenuButton
+            as={IconButton}
+            mt={1}
+            variant="ghost"
+            aria-label="Edit Scenarios"
+            icon={<Icon as={loading ? Spinner : BsGear} />}
+          />
           <MenuList fontSize="md" zIndex="dropdown" mt={-3}>
             <MenuItem
               icon={<Icon as={BsPlus} boxSize={6} mx="-5px" />}
@@ -2,19 +2,24 @@ import {
   Box,
   Button,
   HStack,
+  IconButton,
   Spinner,
+  Text,
   Tooltip,
   useToast,
-  Text,
-  IconButton,
 } from "@chakra-ui/react";
-import { useRef, useEffect, useState, useCallback } from "react";
-import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
-import { type PromptVariant } from "./types";
-import { api } from "~/utils/api";
-import { useAppStore } from "~/state/store";
+import { useCallback, useEffect, useRef, useState } from "react";
 import { FiMaximize, FiMinimize } from "react-icons/fi";
 import { editorBackground } from "~/state/sharedVariantEditor.slice";
+import { useAppStore } from "~/state/store";
+import { api } from "~/utils/api";
+import {
+  useExperimentAccess,
+  useHandledAsyncCallback,
+  useModifierKeyLabel,
+  useVisibleScenarioIds,
+} from "~/utils/hooks";
+import { type PromptVariant } from "./types";

 export default function VariantEditor(props: { variant: PromptVariant }) {
   const { canModify } = useExperimentAccess();
@@ -63,6 +68,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
   const replaceVariant = api.promptVariants.replaceVariant.useMutation();
   const utils = api.useContext();
   const toast = useToast();
+  const visibleScenarios = useVisibleScenarioIds();

   const [onSave, saveInProgress] = useHandledAsyncCallback(async () => {
     if (!editorRef.current) return;
@@ -91,6 +97,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
     const resp = await replaceVariant.mutateAsync({
       id: props.variant.id,
       constructFn: currentFn,
+      streamScenarios: visibleScenarios,
     });
     if (resp.status === "error") {
       return toast({
@@ -16,7 +16,7 @@ import {
 } from "@chakra-ui/react";
 import { BsStars } from "react-icons/bs";
 import { api } from "~/utils/api";
-import { useHandledAsyncCallback } from "~/utils/hooks";
+import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
 import { type PromptVariant } from "@prisma/client";
 import { useState } from "react";
 import CompareFunctions from "./CompareFunctions";
@@ -34,6 +34,7 @@ export const RefinePromptModal = ({
   onClose: () => void;
 }) => {
   const utils = api.useContext();
+  const visibleScenarios = useVisibleScenarioIds();

   const refinementActions =
     frontendModelProviders[variant.modelProvider as SupportedProvider].refinementActions || {};
@@ -73,6 +74,7 @@ export const RefinePromptModal = ({
     await replaceVariantMutation.mutateAsync({
       id: variant.id,
       constructFn: refinedPromptFn,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
     onClose();
@@ -1,8 +1,7 @@
 import { type PromptVariant } from "../OutputsTable/types";
 import { api } from "~/utils/api";
-import { useHandledAsyncCallback } from "~/utils/hooks";
+import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
 import {
-  Button,
   Icon,
   Menu,
   MenuButton,
@@ -11,6 +10,7 @@ import {
   MenuDivider,
   Text,
   Spinner,
+  IconButton,
 } from "@chakra-ui/react";
 import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
 import { FaRegClone } from "react-icons/fa";
@@ -33,11 +33,13 @@ export default function VariantHeaderMenuButton({
   const utils = api.useContext();

   const duplicateMutation = api.promptVariants.create.useMutation();
+  const visibleScenarios = useVisibleScenarioIds();

   const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
     await duplicateMutation.mutateAsync({
       experimentId: variant.experimentId,
       variantId: variant.id,
+      streamScenarios: visibleScenarios,
     });
     await utils.promptVariants.list.invalidate();
   }, [duplicateMutation, variant.experimentId, variant.id]);
@@ -56,15 +58,12 @@ export default function VariantHeaderMenuButton({
   return (
     <>
       <Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
-        {duplicationInProgress ? (
-          <Spinner boxSize={4} mx={3} my={3} />
-        ) : (
-          <MenuButton>
-            <Button variant="ghost">
-              <Icon as={BsGear} />
-            </Button>
-          </MenuButton>
-        )}
+        <MenuButton
+          as={IconButton}
+          variant="ghost"
+          aria-label="Edit Scenarios"
+          icon={<Icon as={duplicationInProgress ? Spinner : BsGear} />}
+        />

         <MenuList mt={-3} fontSize="md">
           <MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
@@ -37,7 +37,7 @@ const modelProvider: OpenaiChatModelProvider = {
     return null;
   },
   inputSchema: inputSchema as JSONSchema4,
-  shouldStream: (input) => input.stream ?? false,
+  canStream: true,
   getCompletion,
   ...frontendModelProvider,
 };
@@ -19,7 +19,7 @@ export async function getCompletion(
 ): Promise<CompletionResponse<ReplicateLlama2Output>> {
   const start = Date.now();

-  const { model, stream, ...rest } = input;
+  const { model, ...rest } = input;

   try {
     const prediction = await replicate.predictions.create({
@@ -9,7 +9,6 @@ export type SupportedModel = (typeof supportedModels)[number];
 export type ReplicateLlama2Input = {
   model: SupportedModel;
   prompt: string;
-  stream?: boolean;
   max_length?: number;
   temperature?: number;
   top_p?: number;
@@ -47,10 +46,6 @@ const modelProvider: ReplicateLlama2Provider = {
         type: "string",
         description: "Prompt to send to Llama v2.",
       },
-      stream: {
-        type: "boolean",
-        description: "Whether to stream output from Llama v2.",
-      },
       max_new_tokens: {
         type: "number",
         description:
@@ -78,7 +73,7 @@ const modelProvider: ReplicateLlama2Provider = {
     },
     required: ["model", "prompt"],
   },
-  shouldStream: (input) => input.stream ?? false,
+  canStream: true,
   getCompletion,
   ...frontendModelProvider,
 };
@@ -48,7 +48,7 @@ export type CompletionResponse<T> =

 export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
   getModel: (input: InputSchema) => SupportedModels | null;
-  shouldStream: (input: InputSchema) => boolean;
+  canStream: boolean;
   inputSchema: JSONSchema4;
   getCompletion: (
     input: InputSchema,
@@ -145,6 +145,7 @@ export const promptVariantsRouter = createTRPCRouter({
       z.object({
         experimentId: z.string(),
         variantId: z.string().optional(),
+        streamScenarios: z.array(z.string()),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -218,7 +219,9 @@ export const promptVariantsRouter = createTRPCRouter({
       });

       for (const scenario of scenarios) {
-        await generateNewCell(newVariant.id, scenario.id);
+        await generateNewCell(newVariant.id, scenario.id, {
+          stream: input.streamScenarios.includes(scenario.id),
+        });
       }

       return newVariant;
@@ -325,6 +328,7 @@ export const promptVariantsRouter = createTRPCRouter({
       z.object({
         id: z.string(),
         constructFn: z.string(),
+        streamScenarios: z.array(z.string()),
       }),
     )
     .mutation(async ({ input, ctx }) => {
@@ -382,7 +386,9 @@ export const promptVariantsRouter = createTRPCRouter({
       });

       for (const scenario of scenarios) {
-        await generateNewCell(newVariant.id, scenario.id);
+        await generateNewCell(newVariant.id, scenario.id, {
+          stream: input.streamScenarios.includes(scenario.id),
+        });
       }

       return { status: "ok" } as const;
@@ -1,8 +1,8 @@
 import { z } from "zod";
 import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
 import { prisma } from "~/server/db";
+import { queueQueryModel } from "~/server/tasks/queryModel.task";
 import { generateNewCell } from "~/server/utils/generateNewCell";
-import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
 import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";

 export const scenarioVariantCellsRouter = createTRPCRouter({
@@ -62,14 +62,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
           testScenarioId: input.scenarioId,
         },
       },
-      include: {
-        modelOutput: true,
-      },
+      include: { modelOutput: true },
     });

     if (!cell) {
-      await generateNewCell(input.variantId, input.scenarioId);
-      return true;
+      await generateNewCell(input.variantId, input.scenarioId, { stream: true });
+      return;
     }

     if (cell.modelOutput) {
@@ -79,12 +77,6 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
       });
     }

-    await prisma.scenarioVariantCell.update({
-      where: { id: cell.id },
-      data: { retrievalStatus: "PENDING" },
-    });
-
-    await queueLLMRetrievalTask(cell.id);
-    return true;
+    await queueQueryModel(cell.id, true);
   }),
 });
@@ -86,7 +86,7 @@ export const scenariosRouter = createTRPCRouter({
       });

       for (const variant of promptVariants) {
-        await generateNewCell(variant.id, scenario.id);
+        await generateNewCell(variant.id, scenario.id, { stream: true });
       }
     }),

@@ -230,7 +230,7 @@ export const scenariosRouter = createTRPCRouter({
       });

       for (const variant of promptVariants) {
-        await generateNewCell(variant.id, newScenario.id);
+        await generateNewCell(variant.id, newScenario.id, { stream: true });
       }

       return newScenario;
@@ -1,17 +1,17 @@
-import { prisma } from "~/server/db";
-import defineTask from "./defineTask";
-import { sleep } from "../utils/sleep";
-import { generateChannel } from "~/utils/generateChannel";
-import { runEvalsForOutput } from "../utils/evaluations";
 import { type Prisma } from "@prisma/client";
-import parseConstructFn from "../utils/parseConstructFn";
-import hashPrompt from "../utils/hashPrompt";
 import { type JsonObject } from "type-fest";
 import modelProviders from "~/modelProviders/modelProviders";
+import { prisma } from "~/server/db";
 import { wsConnection } from "~/utils/wsConnection";
+import { runEvalsForOutput } from "../utils/evaluations";
+import hashPrompt from "../utils/hashPrompt";
+import parseConstructFn from "../utils/parseConstructFn";
+import { sleep } from "../utils/sleep";
+import defineTask from "./defineTask";

-export type queryLLMJob = {
-  scenarioVariantCellId: string;
+export type QueryModelJob = {
+  cellId: string;
+  stream: boolean;
 };

 const MAX_AUTO_RETRIES = 10;
@@ -24,15 +24,16 @@ function calculateDelay(numPreviousTries: number): number {
   return baseDelay + jitter;
 }

-export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
-  const { scenarioVariantCellId } = task;
+export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
+  console.log("RUNNING TASK", task);
+  const { cellId, stream } = task;
   const cell = await prisma.scenarioVariantCell.findUnique({
-    where: { id: scenarioVariantCellId },
+    where: { id: cellId },
     include: { modelOutput: true },
   });
   if (!cell) {
     await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
+      where: { id: cellId },
       data: {
         statusCode: 404,
         errorMessage: "Cell not found",
@@ -47,7 +48,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     return;
   }
   await prisma.scenarioVariantCell.update({
-    where: { id: scenarioVariantCellId },
+    where: { id: cellId },
     data: {
       retrievalStatus: "IN_PROGRESS",
     },
@@ -58,7 +59,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
   });
   if (!variant) {
     await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
+      where: { id: cellId },
       data: {
         statusCode: 404,
         errorMessage: "Prompt Variant not found",
@@ -73,7 +74,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
   });
   if (!scenario) {
     await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
+      where: { id: cellId },
       data: {
         statusCode: 404,
         errorMessage: "Scenario not found",
@@ -87,7 +88,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {

   if ("error" in prompt) {
     await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
+      where: { id: cellId },
       data: {
         statusCode: 400,
         errorMessage: prompt.error,
@@ -99,18 +100,9 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {

   const provider = modelProviders[prompt.modelProvider];

-  const streamingChannel = provider.shouldStream(prompt.modelInput) ? generateChannel() : null;
-
-  if (streamingChannel) {
-    // Save streaming channel so that UI can connect to it
-    await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
-      data: { streamingChannel },
-    });
-  }
-
-  const onStream = streamingChannel
+  const onStream = stream
     ? (partialOutput: (typeof provider)["_outputSchema"]) => {
-        wsConnection.emit("message", { channel: streamingChannel, payload: partialOutput });
+        wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
       }
     : null;

@@ -121,7 +113,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {

     const modelOutput = await prisma.modelOutput.create({
       data: {
-        scenarioVariantCellId,
+        scenarioVariantCellId: cellId,
         inputHash,
         output: response.value as Prisma.InputJsonObject,
         timeToComplete: response.timeToComplete,
@@ -132,7 +124,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     });

     await prisma.scenarioVariantCell.update({
-      where: { id: scenarioVariantCellId },
+      where: { id: cellId },
       data: {
         statusCode: response.statusCode,
         retrievalStatus: "COMPLETE",
@@ -146,7 +138,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
      const delay = calculateDelay(i);

      await prisma.scenarioVariantCell.update({
-       where: { id: scenarioVariantCellId },
+       where: { id: cellId },
        data: {
          errorMessage: response.message,
          statusCode: response.statusCode,
@@ -163,3 +155,21 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
     }
   }
 });
+
+export const queueQueryModel = async (cellId: string, stream: boolean) => {
+  console.log("queueQueryModel", cellId, stream);
+  await Promise.all([
+    prisma.scenarioVariantCell.update({
+      where: {
+        id: cellId,
+      },
+      data: {
+        retrievalStatus: "PENDING",
+        errorMessage: null,
+      },
+    }),
+
+    await queryModel.enqueue({ cellId, stream }),
+    console.log("queued"),
+  ]);
+};
@@ -2,39 +2,27 @@ import { type TaskList, run } from "graphile-worker";
 import "dotenv/config";

 import { env } from "~/env.mjs";
-import { queryLLM } from "./queryLLM.task";
+import { queryModel } from "./queryModel.task";

-const registeredTasks = [queryLLM];
+console.log("Starting worker");
+
+const registeredTasks = [queryModel];

 const taskList = registeredTasks.reduce((acc, task) => {
   acc[task.task.identifier] = task.task.handler;
   return acc;
 }, {} as TaskList);

-async function main() {
-  // Run a worker to execute jobs:
-  const runner = await run({
-    connectionString: env.DATABASE_URL,
-    concurrency: 20,
-    // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
-    noHandleSignals: false,
-    pollInterval: 1000,
-    // you can set the taskList or taskDirectory but not both
-    taskList,
-    // or:
-    // taskDirectory: `${__dirname}/tasks`,
-  });
-
-  // Immediately await (or otherwise handled) the resulting promise, to avoid
-  // "unhandled rejection" errors causing a process crash in the event of
-  // something going wrong.
-  await runner.promise;
-
-  // If the worker exits (whether through fatal error or otherwise), the above
-  // promise will resolve/reject.
-}
-
-main().catch((err) => {
-  console.error("Unhandled error occurred running worker: ", err);
-  process.exit(1);
-});
+// Run a worker to execute jobs:
+const runner = await run({
+  connectionString: env.DATABASE_URL,
+  concurrency: 20,
+  // Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
+  noHandleSignals: false,
+  pollInterval: 1000,
+  taskList,
+});
+
+console.log("Worker successfully started");
+
+await runner.promise;
@@ -1,12 +1,18 @@
 import { type Prisma } from "@prisma/client";
 import { prisma } from "../db";
-import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
 import parseConstructFn from "./parseConstructFn";
 import { type JsonObject } from "type-fest";
 import hashPrompt from "./hashPrompt";
 import { omit } from "lodash-es";
+import { queueQueryModel } from "../tasks/queryModel.task";

-export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
+export const generateNewCell = async (
+  variantId: string,
+  scenarioId: string,
+  options?: { stream?: boolean },
+): Promise<void> => {
+  const stream = options?.stream ?? false;
+
   const variant = await prisma.promptVariant.findUnique({
     where: {
       id: variantId,
@@ -98,6 +104,6 @@ export const generateNewCell = async (variantId: string, scenarioId: string): Pr
       }),
     );
   } else {
-    cell = await queueLLMRetrievalTask(cell.id);
+    await queueQueryModel(cell.id, stream);
   }
 };
@@ -1,22 +0,0 @@
-import { prisma } from "../db";
-import { queryLLM } from "../tasks/queryLLM.task";
-
-export const queueLLMRetrievalTask = async (cellId: string) => {
-  const updatedCell = await prisma.scenarioVariantCell.update({
-    where: {
-      id: cellId,
-    },
-    data: {
-      retrievalStatus: "PENDING",
-      errorMessage: null,
-    },
-    include: {
-      modelOutput: true,
-    },
-  });
-
-  // @ts-expect-error we aren't passing the helpers but that's ok
-  void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
-
-  return updatedCell;
-};
@@ -1,5 +0,0 @@
-// generate random channel id
-
-export const generateChannel = () => {
-  return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
-};
@@ -106,3 +106,5 @@ export const useScenarios = () => {
     { enabled: experiment.data?.id != null },
   );
 };
+
+export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];