Assign experiments and datasets to correct org

This commit is contained in:
David Corbitt
2023-08-07 19:10:27 -07:00
parent 9f17d98736
commit 0f9a83cf45
7 changed files with 280 additions and 268 deletions

View File

@@ -15,6 +15,7 @@ import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs"; import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
type DatasetData = { type DatasetData = {
name: string; name: string;
@@ -71,11 +72,12 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewDatasetCard = () => { export const NewDatasetCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId);
const createMutation = api.datasets.create.useMutation(); const createMutation = api.datasets.create.useMutation();
const [createDataset, isLoading] = useHandledAsyncCallback(async () => { const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
const newDataset = await createMutation.mutateAsync({ label: "New Dataset" }); const newDataset = await createMutation.mutateAsync({ organizationId: selectedOrgId ?? "" });
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
}, [createMutation, router]); }, [createMutation, router, selectedOrgId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -15,6 +15,7 @@ import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs"; import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
type ExperimentData = { type ExperimentData = {
testScenarioCount: number; testScenarioCount: number;
@@ -75,11 +76,17 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewExperimentCard = () => { export const NewExperimentCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId);
const createMutation = api.experiments.create.useMutation(); const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => { const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" }); const newExperiment = await createMutation.mutateAsync({
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } }); organizationId: selectedOrgId ?? "",
}, [createMutation, router]); });
await router.push({
pathname: "/experiments/[id]",
query: { id: newExperiment.id },
});
}, [createMutation, router, selectedOrgId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -1,14 +1,14 @@
import { HStack, Flex, Text } from "@chakra-ui/react"; import { HStack, Flex, Text } from "@chakra-ui/react";
import Link from "next/link"; import Link from "next/link";
import { useSelectedOrg } from "~/utils/hooks"; import { useSelectedOrg } from "~/utils/hooks";
// Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't // Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't
// recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb. // recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb.
export default function ProjectBreadcrumbContents({ orgName = "" }: { orgName?: string }) {
export default function ProjectBreadcrumbContents() {
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedOrg } = useSelectedOrg();
orgName = orgName || selectedOrg?.name || "";
return ( return (
<Link href="/home"> <Link href="/home">
<HStack w="full"> <HStack w="full">
@@ -23,10 +23,10 @@ export default function ProjectBreadcrumbContents() {
alignItems="center" alignItems="center"
justifyContent="center" justifyContent="center"
> >
<Text>{selectedOrg?.name[0]?.toUpperCase()}</Text> <Text>{orgName[0]?.toUpperCase()}</Text>
</Flex> </Flex>
<Text display={{ base: "none", md: "block" }} py={1}> <Text display={{ base: "none", md: "block" }} py={1}>
{selectedOrg?.name} {orgName}
</Text> </Text>
</HStack> </HStack>
</Link> </Link>

View File

@@ -60,7 +60,7 @@ export default function Dataset() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents /> <ProjectBreadcrumbContents orgName={dataset.data?.organization?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/data"> <Link href="/data">

View File

@@ -110,7 +110,7 @@ export default function Experiment() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents /> <ProjectBreadcrumbContents orgName={experiment.data?.organization?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/experiments"> <Link href="/experiments">

View File

@@ -3,11 +3,10 @@ import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { import {
requireCanModifyDataset, requireCanModifyDataset,
requireCanModifyOrganization,
requireCanViewDataset, requireCanViewDataset,
requireCanViewOrganization, requireCanViewOrganization,
requireNothing,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
export const datasetsRouter = createTRPCRouter({ export const datasetsRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
@@ -36,27 +35,27 @@ export const datasetsRouter = createTRPCRouter({
await requireCanViewDataset(input.id, ctx); await requireCanViewDataset(input.id, ctx);
return await prisma.dataset.findFirstOrThrow({ return await prisma.dataset.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: {
organization: true,
},
}); });
}), }),
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { create: protectedProcedure
// Anyone can create an experiment .input(z.object({ organizationId: z.string() }))
requireNothing(ctx); .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx);
const numDatasets = await prisma.dataset.count({ const numDatasets = await prisma.dataset.count({
where: { where: {
organization: { organizationId: input.organizationId,
organizationUsers: {
some: { userId: ctx.session.user.id },
},
},
}, },
}); });
return await prisma.dataset.create({ return await prisma.dataset.create({
data: { data: {
name: `Dataset ${numDatasets + 1}`, name: `Dataset ${numDatasets + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id, organizationId: input.organizationId,
}, },
}); });
}), }),

View File

@@ -8,11 +8,10 @@ import { generateNewCell } from "~/server/utils/generateNewCell";
import { import {
canModifyExperiment, canModifyExperiment,
requireCanModifyExperiment, requireCanModifyExperiment,
requireCanModifyOrganization,
requireCanViewExperiment, requireCanViewExperiment,
requireCanViewOrganization, requireCanViewOrganization,
requireNothing,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
import generateTypes from "~/modelProviders/generateTypes"; import generateTypes from "~/modelProviders/generateTypes";
import { promptConstructorVersion } from "~/promptConstructor/version"; import { promptConstructorVersion } from "~/promptConstructor/version";
@@ -47,7 +46,7 @@ export const experimentsRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ organizationId: z.string() }))
.query(async ({ input, ctx }) => { .query(async ({ input, ctx }) => {
await requireCanViewOrganization(input.organizationId, ctx) await requireCanViewOrganization(input.organizationId, ctx);
const experiments = await prisma.experiment.findMany({ const experiments = await prisma.experiment.findMany({
where: { where: {
@@ -90,6 +89,9 @@ export const experimentsRouter = createTRPCRouter({
await requireCanViewExperiment(input.id, ctx); await requireCanViewExperiment(input.id, ctx);
const experiment = await prisma.experiment.findFirstOrThrow({ const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: {
organization: true,
},
}); });
const canModify = ctx.session?.user.id const canModify = ctx.session?.user.id
@@ -105,8 +107,11 @@ export const experimentsRouter = createTRPCRouter({
}; };
}), }),
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => { fork: protectedProcedure
.input(z.object({ id: z.string(), organizationId: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx); await requireCanViewExperiment(input.id, ctx);
await requireCanModifyOrganization(input.organizationId, ctx);
const [ const [
existingExp, existingExp,
@@ -259,7 +264,7 @@ export const experimentsRouter = createTRPCRouter({
id: newExperimentId, id: newExperimentId,
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`, label: `${existingExp.label} (forked)`,
organizationId: (await userOrg(ctx.session.user.id)).id, organizationId: input.organizationId,
}, },
}), }),
prisma.promptVariant.createMany({ prisma.promptVariant.createMany({
@@ -288,11 +293,10 @@ export const experimentsRouter = createTRPCRouter({
return newExperimentId; return newExperimentId;
}), }),
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { create: protectedProcedure
// Anyone can create an experiment .input(z.object({ organizationId: z.string() }))
requireNothing(ctx); .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx);
const organizationId = (await userOrg(ctx.session.user.id)).id;
const maxSortIndex = const maxSortIndex =
( (
@@ -300,7 +304,7 @@ export const experimentsRouter = createTRPCRouter({
_max: { _max: {
sortIndex: true, sortIndex: true,
}, },
where: { organizationId }, where: { organizationId: input.organizationId },
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
@@ -308,7 +312,7 @@ export const experimentsRouter = createTRPCRouter({
data: { data: {
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`, label: `Experiment ${maxSortIndex + 1}`,
organizationId, organizationId: input.organizationId,
}, },
}); });