diff --git a/app/src/components/StatsCard.tsx b/app/src/components/StatsCard.tsx new file mode 100644 index 0000000..56eacfb --- /dev/null +++ b/app/src/components/StatsCard.tsx @@ -0,0 +1,28 @@ +import { VStack, HStack, type StackProps, Text, Divider } from "@chakra-ui/react"; +import Link, { type LinkProps } from "next/link"; + +const StatsCard = ({ + title, + href, + children, + ...rest +}: { title: string; href: string } & StackProps & LinkProps) => { + return ( + + + + {title} + + + + View all + + + + + {children} + + ); +}; + +export default StatsCard; diff --git a/app/src/components/nav/UserMenu.tsx b/app/src/components/nav/UserMenu.tsx index 8960e5b..9d538bc 100644 --- a/app/src/components/nav/UserMenu.tsx +++ b/app/src/components/nav/UserMenu.tsx @@ -9,6 +9,7 @@ import { PopoverContent, Link, type StackProps, + Box, } from "@chakra-ui/react"; import { type Session } from "next-auth"; import { signOut } from "next-auth/react"; @@ -16,7 +17,6 @@ import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs" import NavSidebarOption from "./NavSidebarOption"; export default function UserMenu({ user, ...rest }: { user: Session } & StackProps) { - const profileImage = user.user.image ? ( profile picture ) : ( @@ -27,26 +27,28 @@ export default function UserMenu({ user, ...rest }: { user: Session } & StackPro <> - - - {profileImage} - - - {user.user.name} - - - {/* {user.user.email} */} - - - - - + + + + {profileImage} + + + {user.user.name} + + + {/* {user.user.email} */} + + + + + + diff --git a/app/src/pages/data/index.tsx b/app/src/pages/data/index.tsx index 843a63e..bbc99ea 100644 --- a/app/src/pages/data/index.tsx +++ b/app/src/pages/data/index.tsx @@ -9,7 +9,6 @@ import { Link, } from "@chakra-ui/react"; import AppShell from "~/components/nav/AppShell"; -import { api } from "~/utils/api"; import { signIn, useSession } from "next-auth/react"; import { RiDatabase2Line } from "react-icons/ri"; import { @@ -19,9 +18,10 @@ import { } from "~/components/datasets/DatasetCard"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; +import { useDatasets } from "~/utils/hooks"; export default function DatasetsPage() { - const datasets = api.datasets.list.useQuery(); + const datasets = useDatasets(); const user = useSession().data; const authLoading = useSession().status === "loading"; diff --git a/app/src/pages/experiments/index.tsx b/app/src/pages/experiments/index.tsx index 59e8f25..666d8e5 100644 --- a/app/src/pages/experiments/index.tsx +++ b/app/src/pages/experiments/index.tsx @@ -10,7 +10,6 @@ import { } from "@chakra-ui/react"; import { RiFlaskLine } from "react-icons/ri"; import AppShell from "~/components/nav/AppShell"; -import { api } from "~/utils/api"; import { ExperimentCard, ExperimentCardSkeleton, @@ -19,9 +18,10 @@ import { import { signIn, useSession } from "next-auth/react"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; +import { useExperiments } from "~/utils/hooks"; export default function ExperimentsPage() { - const experiments = api.experiments.list.useQuery(); + const experiments = useExperiments(); const user = useSession().data; const authLoading = useSession().status === "loading"; diff --git a/app/src/pages/home/index.tsx b/app/src/pages/home/index.tsx index 6927cd8..7f9b3af 100644 --- a/app/src/pages/home/index.tsx +++ b/app/src/pages/home/index.tsx @@ -1,30 +1,15 @@ -import { Breadcrumb, BreadcrumbItem, Text } from "@chakra-ui/react"; -import { useEffect, useState } from "react"; +import { Breadcrumb, BreadcrumbItem, Divider, Text, VStack } from "@chakra-ui/react"; + import AppShell from "~/components/nav/AppShell"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; -import { api } from "~/utils/api"; -import { useHandledAsyncCallback, useSelectedOrg } from "~/utils/hooks"; +import { useExperiments, useSelectedOrg } from "~/utils/hooks"; export default function HomePage() { - const utils = api.useContext(); const { data: selectedOrg } = useSelectedOrg(); - const updateMutation = api.organizations.update.useMutation(); - const [onSaveName] = useHandledAsyncCallback(async () => { - if (name && name !== selectedOrg?.name && selectedOrg?.id) { - await updateMutation.mutateAsync({ - id: selectedOrg.id, - updates: { name }, - }); - await Promise.all([utils.organizations.get.invalidate({ id: selectedOrg.id })]); - } - }, [updateMutation, selectedOrg]); + const experiments = useExperiments(); - const [name, setName] = useState(selectedOrg?.name); - useEffect(() => { - setName(selectedOrg?.name); - }, [selectedOrg?.name]); return ( @@ -33,10 +18,31 @@ export default function HomePage() { - Home + Homepage + + + {selectedOrg?.name} + + + {/* TODO: Add more dashboard cards (one looks weird) */} + {/* + + + {experiments.data?.slice(0, 5).map((exp) => ( + + + {exp.label} + Last updated {formatTimePast(exp.updatedAt)} + + + ))} + + + */} + ); } diff --git a/app/src/server/api/routers/datasets.router.ts b/app/src/server/api/routers/datasets.router.ts index b25fde4..7ba7787 100644 --- a/app/src/server/api/routers/datasets.router.ts +++ b/app/src/server/api/routers/datasets.router.ts @@ -4,35 +4,33 @@ import { prisma } from "~/server/db"; import { requireCanModifyDataset, requireCanViewDataset, + requireCanViewOrganization, requireNothing, } from "~/utils/accessControl"; import userOrg from "~/server/utils/userOrg"; export const datasetsRouter = createTRPCRouter({ - list: protectedProcedure.query(async ({ ctx }) => { - // Anyone can list experiments - requireNothing(ctx); + list: protectedProcedure + .input(z.object({ organizationId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewOrganization(input.organizationId, ctx); - const datasets = await prisma.dataset.findMany({ - where: { - organization: { - organizationUsers: { - some: { userId: ctx.session.user.id }, + const datasets = await prisma.dataset.findMany({ + where: { + organizationId: input.organizationId, + }, + orderBy: { + createdAt: "desc", + }, + include: { + _count: { + select: { datasetEntries: true }, }, }, - }, - orderBy: { - createdAt: "desc", - }, - include: { - _count: { - select: { datasetEntries: true }, - }, - }, - }); + }); - return datasets; - }), + return datasets; + }), get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { await requireCanViewDataset(input.id, ctx); diff --git a/app/src/server/api/routers/experiments.router.ts b/app/src/server/api/routers/experiments.router.ts index 2008baa..699daa6 100644 --- a/app/src/server/api/routers/experiments.router.ts +++ b/app/src/server/api/routers/experiments.router.ts @@ -9,6 +9,7 @@ import { canModifyExperiment, requireCanModifyExperiment, requireCanViewExperiment, + requireCanViewOrganization, requireNothing, } from "~/utils/accessControl"; import userOrg from "~/server/utils/userOrg"; @@ -43,50 +44,47 @@ export const experimentsRouter = createTRPCRouter({ testScenarioCount, }; }), - list: protectedProcedure.query(async ({ ctx }) => { - // Anyone can list experiments - requireNothing(ctx); + list: protectedProcedure + .input(z.object({ organizationId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewOrganization(input.organizationId, ctx) - const experiments = await prisma.experiment.findMany({ - where: { - organization: { - organizationUsers: { - some: { userId: ctx.session.user.id }, - }, + const experiments = await prisma.experiment.findMany({ + where: { + organizationId: input.organizationId, }, - }, - orderBy: { - sortIndex: "desc", - }, - }); + orderBy: { + sortIndex: "desc", + }, + }); - // TODO: look for cleaner way to do this. Maybe aggregate? - const experimentsWithCounts = await Promise.all( - experiments.map(async (experiment) => { - const visibleTestScenarioCount = await prisma.testScenario.count({ - where: { - experimentId: experiment.id, - visible: true, - }, - }); + // TODO: look for cleaner way to do this. Maybe aggregate? + const experimentsWithCounts = await Promise.all( + experiments.map(async (experiment) => { + const visibleTestScenarioCount = await prisma.testScenario.count({ + where: { + experimentId: experiment.id, + visible: true, + }, + }); - const visiblePromptVariantCount = await prisma.promptVariant.count({ - where: { - experimentId: experiment.id, - visible: true, - }, - }); + const visiblePromptVariantCount = await prisma.promptVariant.count({ + where: { + experimentId: experiment.id, + visible: true, + }, + }); - return { - ...experiment, - testScenarioCount: visibleTestScenarioCount, - promptVariantCount: visiblePromptVariantCount, - }; - }), - ); + return { + ...experiment, + testScenarioCount: visibleTestScenarioCount, + promptVariantCount: visiblePromptVariantCount, + }; + }), + ); - return experimentsWithCounts; - }), + return experimentsWithCounts; + }), get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { await requireCanViewExperiment(input.id, ctx); diff --git a/app/src/utils/accessControl.ts b/app/src/utils/accessControl.ts index 98aca75..ad55904 100644 --- a/app/src/utils/accessControl.ts +++ b/app/src/utils/accessControl.ts @@ -16,6 +16,26 @@ export const requireNothing = (ctx: TRPCContext) => { ctx.markAccessControlRun(); }; +export const requireCanViewOrganization = async (organizationId: string, ctx: TRPCContext) => { + const userId = ctx.session?.user.id; + if (!userId) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + const canView = await prisma.organizationUser.findFirst({ + where: { + userId, + organizationId, + }, + }); + + if (!canView) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + ctx.markAccessControlRun(); +}; + export const requireCanModifyOrganization = async (organizationId: string, ctx: TRPCContext) => { const userId = ctx.session?.user.id; if (!userId) { diff --git a/app/src/utils/hooks.ts b/app/src/utils/hooks.ts index 3c1518f..11ae5e3 100644 --- a/app/src/utils/hooks.ts +++ b/app/src/utils/hooks.ts @@ -4,6 +4,14 @@ import { api } from "~/utils/api"; import { NumberParam, useQueryParam, withDefault } from "use-query-params"; import { useAppStore } from "~/state/store"; +export const useExperiments = () => { + const selectedOrgId = useAppStore((state) => state.selectedOrgId); + return api.experiments.list.useQuery( + { organizationId: selectedOrgId ?? "" }, + { enabled: !!selectedOrgId }, + ); +}; + export const useExperiment = () => { const router = useRouter(); const experiment = api.experiments.get.useQuery( @@ -18,6 +26,14 @@ export const useExperimentAccess = () => { return useExperiment().data?.access ?? { canView: false, canModify: false }; }; +export const useDatasets = () => { + const selectedOrgId = useAppStore((state) => state.selectedOrgId); + return api.datasets.list.useQuery( + { organizationId: selectedOrgId ?? "" }, + { enabled: !!selectedOrgId }, + ); +}; + export const useDataset = () => { const router = useRouter(); const dataset = api.datasets.get.useQuery( @@ -136,8 +152,5 @@ export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s export const useSelectedOrg = () => { const selectedOrgId = useAppStore((state) => state.selectedOrgId); - return api.organizations.get.useQuery( - { id: selectedOrgId ?? "" }, - { enabled: !!selectedOrgId }, - ); + return api.organizations.get.useQuery({ id: selectedOrgId ?? "" }, { enabled: !!selectedOrgId }); };