Record experiment updated in more places (#24)

* Record experiment updated in more places

* Update experiment updatedAt in same transaction
This commit is contained in:
arcticfly
2023-07-10 11:00:24 -07:00
committed by GitHub
parent d6a46b9e9d
commit e64a94e06e
3 changed files with 44 additions and 23 deletions

View File

@@ -3,6 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { type OpenAIChatConfig } from "~/server/types"; import { type OpenAIChatConfig } from "~/server/types";
import { getModelName } from "~/server/utils/getModelName"; import { getModelName } from "~/server/utils/getModelName";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { calculateTokenCost } from "~/utils/calculateTokenCost";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
@@ -98,7 +99,7 @@ export const promptVariantsRouter = createTRPCRouter({
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
const newVariant = await prisma.promptVariant.create({ const createNewVariantAction = prisma.promptVariant.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`, label: `Prompt Variant ${largestSortIndex + 2}`,
@@ -107,14 +108,10 @@ export const promptVariantsRouter = createTRPCRouter({
}, },
}); });
await prisma.experiment.update({ const [newVariant] = await prisma.$transaction([
where: { createNewVariantAction,
id: input.experimentId, recordExperimentUpdated(input.experimentId)
}, ]);
data: {
updatedAt: new Date(),
},
});
return newVariant; return newVariant;
}), }),
@@ -139,12 +136,20 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`); throw new Error(`Prompt Variant with id ${input.id} does not exist`);
} }
return await prisma.promptVariant.update({ const updatePromptVariantAction = prisma.promptVariant.update({
where: { where: {
id: input.id, id: input.id,
}, },
data: input.updates, data: input.updates,
}); });
const [updatedPromptVariant] = await prisma.$transaction([
updatePromptVariantAction,
recordExperimentUpdated(existing.experimentId)
]);
return updatedPromptVariant;
}), }),
hide: publicProcedure hide: publicProcedure
@@ -154,10 +159,12 @@ export const promptVariantsRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
return await prisma.promptVariant.update({ const updatedPromptVariant = await prisma.promptVariant.update({
where: { id: input.id }, where: { id: input.id },
data: { visible: false }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
}); });
return updatedPromptVariant;
}), }),
replaceWithConfig: publicProcedure replaceWithConfig: publicProcedure
@@ -197,7 +204,7 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
// Hide anything with the same uiId besides the new one // Hide anything with the same uiId besides the new one
await prisma.promptVariant.updateMany({ const hideOldVariantsAction = prisma.promptVariant.updateMany({
where: { where: {
uiId: existing.uiId, uiId: existing.uiId,
id: { id: {
@@ -209,6 +216,11 @@ export const promptVariantsRouter = createTRPCRouter({
}, },
}); });
await prisma.$transaction([
hideOldVariantsAction,
recordExperimentUpdated(existing.experimentId)
]);
return newVariant; return newVariant;
}), }),

View File

@@ -2,6 +2,7 @@ import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen"; import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
export const scenariosRouter = createTRPCRouter({ export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -36,7 +37,7 @@ export const scenariosRouter = createTRPCRouter({
}) })
)._max.sortIndex ?? 0; )._max.sortIndex ?? 0;
await prisma.testScenario.create({ const createNewScenarioAction = prisma.testScenario.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
@@ -46,20 +47,16 @@ export const scenariosRouter = createTRPCRouter({
}, },
}); });
await prisma.experiment.update({ await prisma.$transaction([
where: { createNewScenarioAction,
id: input.experimentId, recordExperimentUpdated(input.experimentId)
}, ]);
data: {
updatedAt: new Date(),
},
});
}), }),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
return await prisma.testScenario.update({ return await prisma.testScenario.update({
where: { id: input.id }, where: { id: input.id },
data: { visible: false }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
}); });
}), }),

View File

@@ -0,0 +1,12 @@
import { prisma } from "~/server/db";
export const recordExperimentUpdated = (experimentId: string) => {
return prisma.experiment.update({
where: {
id: experimentId,
},
data: {
updatedAt: new Date(),
},
});
};