diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index 71c449a..c90251d 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -3,6 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { type OpenAIChatConfig } from "~/server/types"; import { getModelName } from "~/server/utils/getModelName"; +import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { calculateTokenCost } from "~/utils/calculateTokenCost"; export const promptVariantsRouter = createTRPCRouter({ @@ -98,7 +99,7 @@ export const promptVariantsRouter = createTRPCRouter({ }) )._max?.sortIndex ?? 0; - const newVariant = await prisma.promptVariant.create({ + const createNewVariantAction = prisma.promptVariant.create({ data: { experimentId: input.experimentId, label: `Prompt Variant ${largestSortIndex + 2}`, @@ -107,14 +108,10 @@ export const promptVariantsRouter = createTRPCRouter({ }, }); - await prisma.experiment.update({ - where: { - id: input.experimentId, - }, - data: { - updatedAt: new Date(), - }, - }); + const [newVariant] = await prisma.$transaction([ + createNewVariantAction, + recordExperimentUpdated(input.experimentId) + ]); return newVariant; }), @@ -139,12 +136,20 @@ export const promptVariantsRouter = createTRPCRouter({ throw new Error(`Prompt Variant with id ${input.id} does not exist`); } - return await prisma.promptVariant.update({ + const updatePromptVariantAction = prisma.promptVariant.update({ where: { id: input.id, }, data: input.updates, }); + + const [updatedPromptVariant] = await prisma.$transaction([ + updatePromptVariantAction, + recordExperimentUpdated(existing.experimentId) + ]); + + return updatedPromptVariant; + }), hide: publicProcedure @@ -154,10 +159,12 @@ export const promptVariantsRouter = createTRPCRouter({ }), ) .mutation(async ({ input }) => { - return await prisma.promptVariant.update({ + const updatedPromptVariant = await prisma.promptVariant.update({ where: { id: input.id }, - data: { visible: false }, + data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, }); + + return updatedPromptVariant; }), replaceWithConfig: publicProcedure @@ -197,7 +204,7 @@ export const promptVariantsRouter = createTRPCRouter({ }); // Hide anything with the same uiId besides the new one - await prisma.promptVariant.updateMany({ + const hideOldVariantsAction = prisma.promptVariant.updateMany({ where: { uiId: existing.uiId, id: { @@ -209,6 +216,11 @@ export const promptVariantsRouter = createTRPCRouter({ }, }); + await prisma.$transaction([ + hideOldVariantsAction, + recordExperimentUpdated(existing.experimentId) + ]); + return newVariant; }), diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts index ee23b36..f5b3d1e 100644 --- a/src/server/api/routers/scenarios.router.ts +++ b/src/server/api/routers/scenarios.router.ts @@ -2,6 +2,7 @@ import { z } from "zod"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { autogenerateScenarioValues } from "../autogen"; +import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; export const scenariosRouter = createTRPCRouter({ list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { @@ -36,7 +37,7 @@ export const scenariosRouter = createTRPCRouter({ }) )._max.sortIndex ?? 0; - await prisma.testScenario.create({ + const createNewScenarioAction = prisma.testScenario.create({ data: { experimentId: input.experimentId, sortIndex: maxSortIndex + 1, @@ -46,20 +47,16 @@ export const scenariosRouter = createTRPCRouter({ }, }); - await prisma.experiment.update({ - where: { - id: input.experimentId, - }, - data: { - updatedAt: new Date(), - }, - }); + await prisma.$transaction([ + createNewScenarioAction, + recordExperimentUpdated(input.experimentId) + ]); }), hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { return await prisma.testScenario.update({ where: { id: input.id }, - data: { visible: false }, + data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, }); }), diff --git a/src/server/utils/recordExperimentUpdated.ts b/src/server/utils/recordExperimentUpdated.ts new file mode 100644 index 0000000..f727845 --- /dev/null +++ b/src/server/utils/recordExperimentUpdated.ts @@ -0,0 +1,12 @@ +import { prisma } from "~/server/db"; + +export const recordExperimentUpdated = (experimentId: string) => { + return prisma.experiment.update({ + where: { + id: experimentId, + }, + data: { + updatedAt: new Date(), + }, + }); +};