From 0f9a83cf45e2e3ae6169966ba6815a515f964b3d Mon Sep 17 00:00:00 2001 From: David Corbitt Date: Mon, 7 Aug 2023 19:10:27 -0700 Subject: [PATCH] Assign experiments and datasets to correct org --- app/src/components/datasets/DatasetCard.tsx | 6 +- .../components/experiments/ExperimentCard.tsx | 13 +- .../nav/ProjectBreadcrumbContents.tsx | 10 +- app/src/pages/data/[id].tsx | 2 +- app/src/pages/experiments/[id].tsx | 2 +- app/src/server/api/routers/datasets.router.ts | 39 +- .../server/api/routers/experiments.router.ts | 476 +++++++++--------- 7 files changed, 280 insertions(+), 268 deletions(-) diff --git a/app/src/components/datasets/DatasetCard.tsx b/app/src/components/datasets/DatasetCard.tsx index d5f9344..fc59911 100644 --- a/app/src/components/datasets/DatasetCard.tsx +++ b/app/src/components/datasets/DatasetCard.tsx @@ -15,6 +15,7 @@ import { useRouter } from "next/router"; import { BsPlusSquare } from "react-icons/bs"; import { api } from "~/utils/api"; import { useHandledAsyncCallback } from "~/utils/hooks"; +import { useAppStore } from "~/state/store"; type DatasetData = { name: string; @@ -71,11 +72,12 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => { export const NewDatasetCard = () => { const router = useRouter(); + const selectedOrgId = useAppStore((s) => s.selectedOrgId); const createMutation = api.datasets.create.useMutation(); const [createDataset, isLoading] = useHandledAsyncCallback(async () => { - const newDataset = await createMutation.mutateAsync({ label: "New Dataset" }); + const newDataset = await createMutation.mutateAsync({ organizationId: selectedOrgId ?? "" }); await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); - }, [createMutation, router]); + }, [createMutation, router, selectedOrgId]); return ( diff --git a/app/src/components/experiments/ExperimentCard.tsx b/app/src/components/experiments/ExperimentCard.tsx index 9a8f4c3..01853af 100644 --- a/app/src/components/experiments/ExperimentCard.tsx +++ b/app/src/components/experiments/ExperimentCard.tsx @@ -15,6 +15,7 @@ import { useRouter } from "next/router"; import { BsPlusSquare } from "react-icons/bs"; import { api } from "~/utils/api"; import { useHandledAsyncCallback } from "~/utils/hooks"; +import { useAppStore } from "~/state/store"; type ExperimentData = { testScenarioCount: number; @@ -75,11 +76,17 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => { export const NewExperimentCard = () => { const router = useRouter(); + const selectedOrgId = useAppStore((s) => s.selectedOrgId); const createMutation = api.experiments.create.useMutation(); const [createExperiment, isLoading] = useHandledAsyncCallback(async () => { - const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" }); - await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } }); - }, [createMutation, router]); + const newExperiment = await createMutation.mutateAsync({ + organizationId: selectedOrgId ?? "", + }); + await router.push({ + pathname: "/experiments/[id]", + query: { id: newExperiment.id }, + }); + }, [createMutation, router, selectedOrgId]); return ( diff --git a/app/src/components/nav/ProjectBreadcrumbContents.tsx b/app/src/components/nav/ProjectBreadcrumbContents.tsx index 68de209..c2ecdbc 100644 --- a/app/src/components/nav/ProjectBreadcrumbContents.tsx +++ b/app/src/components/nav/ProjectBreadcrumbContents.tsx @@ -1,14 +1,14 @@ import { HStack, Flex, Text } from "@chakra-ui/react"; import Link from "next/link"; - import { useSelectedOrg } from "~/utils/hooks"; // Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't // recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb. - -export default function ProjectBreadcrumbContents() { +export default function ProjectBreadcrumbContents({ orgName = "" }: { orgName?: string }) { const { data: selectedOrg } = useSelectedOrg(); + orgName = orgName || selectedOrg?.name || ""; + return ( @@ -23,10 +23,10 @@ export default function ProjectBreadcrumbContents() { alignItems="center" justifyContent="center" > - {selectedOrg?.name[0]?.toUpperCase()} + {orgName[0]?.toUpperCase()} - {selectedOrg?.name} + {orgName} diff --git a/app/src/pages/data/[id].tsx b/app/src/pages/data/[id].tsx index 63921a6..b4a79fb 100644 --- a/app/src/pages/data/[id].tsx +++ b/app/src/pages/data/[id].tsx @@ -60,7 +60,7 @@ export default function Dataset() { - + diff --git a/app/src/pages/experiments/[id].tsx b/app/src/pages/experiments/[id].tsx index 04ce91d..d1df7cd 100644 --- a/app/src/pages/experiments/[id].tsx +++ b/app/src/pages/experiments/[id].tsx @@ -110,7 +110,7 @@ export default function Experiment() { - + diff --git a/app/src/server/api/routers/datasets.router.ts b/app/src/server/api/routers/datasets.router.ts index 7ba7787..80bfd26 100644 --- a/app/src/server/api/routers/datasets.router.ts +++ b/app/src/server/api/routers/datasets.router.ts @@ -3,11 +3,10 @@ import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/ import { prisma } from "~/server/db"; import { requireCanModifyDataset, + requireCanModifyOrganization, requireCanViewDataset, requireCanViewOrganization, - requireNothing, } from "~/utils/accessControl"; -import userOrg from "~/server/utils/userOrg"; export const datasetsRouter = createTRPCRouter({ list: protectedProcedure @@ -36,30 +35,30 @@ export const datasetsRouter = createTRPCRouter({ await requireCanViewDataset(input.id, ctx); return await prisma.dataset.findFirstOrThrow({ where: { id: input.id }, + include: { + organization: true, + }, }); }), - create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { - // Anyone can create an experiment - requireNothing(ctx); + create: protectedProcedure + .input(z.object({ organizationId: z.string() })) + .mutation(async ({ input, ctx }) => { + await requireCanModifyOrganization(input.organizationId, ctx); - const numDatasets = await prisma.dataset.count({ - where: { - organization: { - organizationUsers: { - some: { userId: ctx.session.user.id }, - }, + const numDatasets = await prisma.dataset.count({ + where: { + organizationId: input.organizationId, }, - }, - }); + }); - return await prisma.dataset.create({ - data: { - name: `Dataset ${numDatasets + 1}`, - organizationId: (await userOrg(ctx.session.user.id)).id, - }, - }); - }), + return await prisma.dataset.create({ + data: { + name: `Dataset ${numDatasets + 1}`, + organizationId: input.organizationId, + }, + }); + }), update: protectedProcedure .input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) })) diff --git a/app/src/server/api/routers/experiments.router.ts b/app/src/server/api/routers/experiments.router.ts index 699daa6..b69e3d2 100644 --- a/app/src/server/api/routers/experiments.router.ts +++ b/app/src/server/api/routers/experiments.router.ts @@ -8,11 +8,10 @@ import { generateNewCell } from "~/server/utils/generateNewCell"; import { canModifyExperiment, requireCanModifyExperiment, + requireCanModifyOrganization, requireCanViewExperiment, requireCanViewOrganization, - requireNothing, } from "~/utils/accessControl"; -import userOrg from "~/server/utils/userOrg"; import generateTypes from "~/modelProviders/generateTypes"; import { promptConstructorVersion } from "~/promptConstructor/version"; @@ -47,7 +46,7 @@ export const experimentsRouter = createTRPCRouter({ list: protectedProcedure .input(z.object({ organizationId: z.string() })) .query(async ({ input, ctx }) => { - await requireCanViewOrganization(input.organizationId, ctx) + await requireCanViewOrganization(input.organizationId, ctx); const experiments = await prisma.experiment.findMany({ where: { @@ -90,6 +89,9 @@ export const experimentsRouter = createTRPCRouter({ await requireCanViewExperiment(input.id, ctx); const experiment = await prisma.experiment.findFirstOrThrow({ where: { id: input.id }, + include: { + organization: true, + }, }); const canModify = ctx.session?.user.id @@ -105,222 +107,224 @@ export const experimentsRouter = createTRPCRouter({ }; }), - fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => { - await requireCanViewExperiment(input.id, ctx); + fork: protectedProcedure + .input(z.object({ id: z.string(), organizationId: z.string() })) + .mutation(async ({ input, ctx }) => { + await requireCanViewExperiment(input.id, ctx); + await requireCanModifyOrganization(input.organizationId, ctx); - const [ - existingExp, - existingVariants, - existingScenarios, - existingCells, - evaluations, - templateVariables, - ] = await prisma.$transaction([ - prisma.experiment.findUniqueOrThrow({ - where: { - id: input.id, - }, - }), - prisma.promptVariant.findMany({ - where: { - experimentId: input.id, - visible: true, - }, - }), - prisma.testScenario.findMany({ - where: { - experimentId: input.id, - visible: true, - }, - }), - prisma.scenarioVariantCell.findMany({ - where: { - testScenario: { - visible: true, + const [ + existingExp, + existingVariants, + existingScenarios, + existingCells, + evaluations, + templateVariables, + ] = await prisma.$transaction([ + prisma.experiment.findUniqueOrThrow({ + where: { + id: input.id, }, - promptVariant: { + }), + prisma.promptVariant.findMany({ + where: { experimentId: input.id, visible: true, }, - }, - include: { - modelResponses: { - include: { - outputEvaluations: true, + }), + prisma.testScenario.findMany({ + where: { + experimentId: input.id, + visible: true, + }, + }), + prisma.scenarioVariantCell.findMany({ + where: { + testScenario: { + visible: true, + }, + promptVariant: { + experimentId: input.id, + visible: true, }, }, - }, - }), - prisma.evaluation.findMany({ - where: { - experimentId: input.id, - }, - }), - prisma.templateVariable.findMany({ - where: { - experimentId: input.id, - }, - }), - ]); + include: { + modelResponses: { + include: { + outputEvaluations: true, + }, + }, + }, + }), + prisma.evaluation.findMany({ + where: { + experimentId: input.id, + }, + }), + prisma.templateVariable.findMany({ + where: { + experimentId: input.id, + }, + }), + ]); - const newExperimentId = uuidv4(); + const newExperimentId = uuidv4(); - const existingToNewVariantIds = new Map(); - const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = []; - for (const variant of existingVariants) { - const newVariantId = uuidv4(); - existingToNewVariantIds.set(variant.id, newVariantId); - variantsToCreate.push({ - ...variant, - id: newVariantId, - experimentId: newExperimentId, - }); - } - - const existingToNewScenarioIds = new Map(); - const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = []; - for (const scenario of existingScenarios) { - const newScenarioId = uuidv4(); - existingToNewScenarioIds.set(scenario.id, newScenarioId); - scenariosToCreate.push({ - ...scenario, - id: newScenarioId, - experimentId: newExperimentId, - variableValues: scenario.variableValues as Prisma.InputJsonValue, - }); - } - - const existingToNewEvaluationIds = new Map(); - const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = []; - for (const evaluation of evaluations) { - const newEvaluationId = uuidv4(); - existingToNewEvaluationIds.set(evaluation.id, newEvaluationId); - evaluationsToCreate.push({ - ...evaluation, - id: newEvaluationId, - experimentId: newExperimentId, - }); - } - - const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = []; - const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = []; - const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = []; - for (const cell of existingCells) { - const newCellId = uuidv4(); - const { modelResponses, ...cellData } = cell; - cellsToCreate.push({ - ...cellData, - id: newCellId, - promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "", - testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "", - prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined, - }); - for (const modelResponse of modelResponses) { - const newModelResponseId = uuidv4(); - const { outputEvaluations, ...modelResponseData } = modelResponse; - modelResponsesToCreate.push({ - ...modelResponseData, - id: newModelResponseId, - scenarioVariantCellId: newCellId, - output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, + const existingToNewVariantIds = new Map(); + const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = []; + for (const variant of existingVariants) { + const newVariantId = uuidv4(); + existingToNewVariantIds.set(variant.id, newVariantId); + variantsToCreate.push({ + ...variant, + id: newVariantId, + experimentId: newExperimentId, }); - for (const evaluation of outputEvaluations) { - outputEvaluationsToCreate.push({ - ...evaluation, - id: uuidv4(), - modelResponseId: newModelResponseId, - evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "", + } + + const existingToNewScenarioIds = new Map(); + const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = []; + for (const scenario of existingScenarios) { + const newScenarioId = uuidv4(); + existingToNewScenarioIds.set(scenario.id, newScenarioId); + scenariosToCreate.push({ + ...scenario, + id: newScenarioId, + experimentId: newExperimentId, + variableValues: scenario.variableValues as Prisma.InputJsonValue, + }); + } + + const existingToNewEvaluationIds = new Map(); + const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = []; + for (const evaluation of evaluations) { + const newEvaluationId = uuidv4(); + existingToNewEvaluationIds.set(evaluation.id, newEvaluationId); + evaluationsToCreate.push({ + ...evaluation, + id: newEvaluationId, + experimentId: newExperimentId, + }); + } + + const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = []; + const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = []; + const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = []; + for (const cell of existingCells) { + const newCellId = uuidv4(); + const { modelResponses, ...cellData } = cell; + cellsToCreate.push({ + ...cellData, + id: newCellId, + promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "", + testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "", + prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined, + }); + for (const modelResponse of modelResponses) { + const newModelResponseId = uuidv4(); + const { outputEvaluations, ...modelResponseData } = modelResponse; + modelResponsesToCreate.push({ + ...modelResponseData, + id: newModelResponseId, + scenarioVariantCellId: newCellId, + output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, }); + for (const evaluation of outputEvaluations) { + outputEvaluationsToCreate.push({ + ...evaluation, + id: uuidv4(), + modelResponseId: newModelResponseId, + evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "", + }); + } } } - } - const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = []; - for (const templateVariable of templateVariables) { - templateVariablesToCreate.push({ - ...templateVariable, - id: uuidv4(), - experimentId: newExperimentId, - }); - } + const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = []; + for (const templateVariable of templateVariables) { + templateVariablesToCreate.push({ + ...templateVariable, + id: uuidv4(), + experimentId: newExperimentId, + }); + } - const maxSortIndex = - ( - await prisma.experiment.aggregate({ - _max: { - sortIndex: true, + const maxSortIndex = + ( + await prisma.experiment.aggregate({ + _max: { + sortIndex: true, + }, + }) + )._max?.sortIndex ?? 0; + + await prisma.$transaction([ + prisma.experiment.create({ + data: { + id: newExperimentId, + sortIndex: maxSortIndex + 1, + label: `${existingExp.label} (forked)`, + organizationId: input.organizationId, }, - }) - )._max?.sortIndex ?? 0; + }), + prisma.promptVariant.createMany({ + data: variantsToCreate, + }), + prisma.testScenario.createMany({ + data: scenariosToCreate, + }), + prisma.scenarioVariantCell.createMany({ + data: cellsToCreate, + }), + prisma.modelResponse.createMany({ + data: modelResponsesToCreate, + }), + prisma.evaluation.createMany({ + data: evaluationsToCreate, + }), + prisma.outputEvaluation.createMany({ + data: outputEvaluationsToCreate, + }), + prisma.templateVariable.createMany({ + data: templateVariablesToCreate, + }), + ]); - await prisma.$transaction([ - prisma.experiment.create({ + return newExperimentId; + }), + + create: protectedProcedure + .input(z.object({ organizationId: z.string() })) + .mutation(async ({ input, ctx }) => { + await requireCanModifyOrganization(input.organizationId, ctx); + + const maxSortIndex = + ( + await prisma.experiment.aggregate({ + _max: { + sortIndex: true, + }, + where: { organizationId: input.organizationId }, + }) + )._max?.sortIndex ?? 0; + + const exp = await prisma.experiment.create({ data: { - id: newExperimentId, sortIndex: maxSortIndex + 1, - label: `${existingExp.label} (forked)`, - organizationId: (await userOrg(ctx.session.user.id)).id, + label: `Experiment ${maxSortIndex + 1}`, + organizationId: input.organizationId, }, - }), - prisma.promptVariant.createMany({ - data: variantsToCreate, - }), - prisma.testScenario.createMany({ - data: scenariosToCreate, - }), - prisma.scenarioVariantCell.createMany({ - data: cellsToCreate, - }), - prisma.modelResponse.createMany({ - data: modelResponsesToCreate, - }), - prisma.evaluation.createMany({ - data: evaluationsToCreate, - }), - prisma.outputEvaluation.createMany({ - data: outputEvaluationsToCreate, - }), - prisma.templateVariable.createMany({ - data: templateVariablesToCreate, - }), - ]); + }); - return newExperimentId; - }), - - create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { - // Anyone can create an experiment - requireNothing(ctx); - - const organizationId = (await userOrg(ctx.session.user.id)).id; - - const maxSortIndex = - ( - await prisma.experiment.aggregate({ - _max: { - sortIndex: true, - }, - where: { organizationId }, - }) - )._max?.sortIndex ?? 0; - - const exp = await prisma.experiment.create({ - data: { - sortIndex: maxSortIndex + 1, - label: `Experiment ${maxSortIndex + 1}`, - organizationId, - }, - }); - - const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([ - prisma.promptVariant.create({ - data: { - experimentId: exp.id, - label: "Prompt Variant 1", - sortIndex: 0, - // The interpolated $ is necessary until dedent incorporates - // https://github.com/dmnd/dedent/pull/46 - promptConstructor: dedent` + const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([ + prisma.promptVariant.create({ + data: { + experimentId: exp.id, + label: "Prompt Variant 1", + sortIndex: 0, + // The interpolated $ is necessary until dedent incorporates + // https://github.com/dmnd/dedent/pull/46 + promptConstructor: dedent` /** * Use Javascript to define an OpenAI chat completion * (https://platform.openai.com/docs/api-reference/chat/create). @@ -339,49 +343,49 @@ export const experimentsRouter = createTRPCRouter({ }, ], });`, - model: "gpt-3.5-turbo-0613", - modelProvider: "openai/ChatCompletion", - promptConstructorVersion, - }, - }), - prisma.templateVariable.create({ - data: { - experimentId: exp.id, - label: "language", - }, - }), - prisma.testScenario.create({ - data: { - experimentId: exp.id, - variableValues: { - language: "English", + model: "gpt-3.5-turbo-0613", + modelProvider: "openai/ChatCompletion", + promptConstructorVersion, }, - }, - }), - prisma.testScenario.create({ - data: { - experimentId: exp.id, - variableValues: { - language: "Spanish", + }), + prisma.templateVariable.create({ + data: { + experimentId: exp.id, + label: "language", }, - }, - }), - prisma.testScenario.create({ - data: { - experimentId: exp.id, - variableValues: { - language: "German", + }), + prisma.testScenario.create({ + data: { + experimentId: exp.id, + variableValues: { + language: "English", + }, }, - }, - }), - ]); + }), + prisma.testScenario.create({ + data: { + experimentId: exp.id, + variableValues: { + language: "Spanish", + }, + }, + }), + prisma.testScenario.create({ + data: { + experimentId: exp.id, + variableValues: { + language: "German", + }, + }, + }), + ]); - await generateNewCell(variant.id, scenario1.id); - await generateNewCell(variant.id, scenario2.id); - await generateNewCell(variant.id, scenario3.id); + await generateNewCell(variant.id, scenario1.id); + await generateNewCell(variant.id, scenario2.id); + await generateNewCell(variant.id, scenario3.id); - return exp; - }), + return exp; + }), update: protectedProcedure .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))