diff --git a/.env.example b/.env.example index 1d6e9d7..d0ecc56 100644 --- a/.env.example +++ b/.env.example @@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ OPENAI_API_KEY="" NEXT_PUBLIC_SOCKET_URL="http://localhost:3318" + +# Next Auth +NEXTAUTH_SECRET="your_secret" +NEXTAUTH_URL="http://localhost:3000" + +# Next Auth Github Provider +GITHUB_CLIENT_ID="your_client_id" +GITHUB_CLIENT_SECRET="your_secret" diff --git a/@types/nextjs-routes.d.ts b/@types/nextjs-routes.d.ts index 08e1564..d16d76e 100644 --- a/@types/nextjs-routes.d.ts +++ b/@types/nextjs-routes.d.ts @@ -11,6 +11,7 @@ declare module "nextjs-routes" { } from "next"; export type Route = + | StaticRoute<"/account/signin"> | DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }> | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }> | DynamicRoute<"/experiments/[id]", { "id": string }> diff --git a/README.md b/README.md index 995a484..bdcbadb 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,18 @@ OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data. -**Live Demo:** https://openpipe.ai +## Sample Experiments + +These are simple experiments users have created that show how OpenPipe works. + +- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111) +- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222) +- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b) +- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8) demo -Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally). +You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally). ## High-Level Features @@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned. 5. Install the dependencies: `cd openpipe && pnpm install` 6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`. 7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database. -8. Start the app: `pnpm dev`. -9. Navigate to [http://localhost:3000](http://localhost:3000) +8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!) +9. Start the app: `pnpm dev`. +10. Navigate to [http://localhost:3000](http://localhost:3000) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3472e54..b098a06 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.0' +lockfileVersion: '6.1' settings: autoInstallPeers: true diff --git a/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql b/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql new file mode 100644 index 0000000..bbd3152 --- /dev/null +++ b/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql @@ -0,0 +1,124 @@ +DROP TABLE "Account"; +DROP TABLE "Session"; +DROP TABLE "User"; +DROP TABLE "VerificationToken"; + +CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER'); + +-- CreateTable +CREATE TABLE "Organization" ( + "id" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "personalOrgUserId" UUID, + + CONSTRAINT "Organization_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "OrganizationUser" ( + "id" UUID NOT NULL, + "role" "OrganizationUserRole" NOT NULL, + "organizationId" UUID NOT NULL, + "userId" UUID NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "Account" ( + "id" UUID NOT NULL, + "userId" UUID NOT NULL, + "type" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "providerAccountId" TEXT NOT NULL, + "refresh_token" TEXT, + "refresh_token_expires_in" INTEGER, + "access_token" TEXT, + "expires_at" INTEGER, + "token_type" TEXT, + "scope" TEXT, + "id_token" TEXT, + "session_state" TEXT, + + CONSTRAINT "Account_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "Session" ( + "id" UUID NOT NULL, + "sessionToken" TEXT NOT NULL, + "userId" UUID NOT NULL, + "expires" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "Session_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "User" ( + "id" UUID NOT NULL, + "name" TEXT, + "email" TEXT, + "emailVerified" TIMESTAMP(3), + "image" TEXT, + + CONSTRAINT "User_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "VerificationToken" ( + "identifier" TEXT NOT NULL, + "token" TEXT NOT NULL, + "expires" TIMESTAMP(3) NOT NULL +); + +INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP); + +-- AlterTable add organizationId as NULLABLE +ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID; + +-- Set default organization for existing experiments +UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111'; + +-- AlterTable set organizationId as NOT NULL +ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL; + + +-- CreateIndex +CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId"); + +-- CreateIndex +CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId"); + +-- CreateIndex +CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken"); + +-- CreateIndex +CREATE UNIQUE INDEX "User_email_key" ON "User"("email"); + +-- CreateIndex +CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token"); + +-- CreateIndex +CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token"); + +-- AddForeignKey +ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId"); + +ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index e8f8ea7..1afa8fc 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -16,8 +16,12 @@ model Experiment { sortIndex Int @default(0) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt + organizationId String @db.Uuid + organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + TemplateVariable TemplateVariable[] PromptVariant PromptVariant[] TestScenario TestScenario[] @@ -169,41 +173,77 @@ model OutputEvaluation { @@unique([modelOutputId, evaluationId]) } -// Necessary for Next auth +model Organization { + id String @id @default(uuid()) @db.Uuid + personalOrgUserId String? @unique @db.Uuid + PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + OrganizationUser OrganizationUser[] + Experiment Experiment[] +} + +enum OrganizationUserRole { + ADMIN + MEMBER + VIEWER +} + +model OrganizationUser { + id String @id @default(uuid()) @db.Uuid + + role OrganizationUserRole + + organizationId String @db.Uuid + organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) + + userId String @db.Uuid + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([organizationId, userId]) +} + model Account { - id String @id @default(cuid()) - userId String - type String - provider String - providerAccountId String - refresh_token String? // @db.Text - access_token String? // @db.Text - expires_at Int? - token_type String? - scope String? - id_token String? // @db.Text - session_state String? - user User @relation(fields: [userId], references: [id], onDelete: Cascade) + id String @id @default(uuid()) @db.Uuid + userId String @db.Uuid + type String + provider String + providerAccountId String + refresh_token String? @db.Text + refresh_token_expires_in Int? + access_token String? @db.Text + expires_at Int? + token_type String? + scope String? + id_token String? @db.Text + session_state String? + user User @relation(fields: [userId], references: [id], onDelete: Cascade) @@unique([provider, providerAccountId]) } model Session { - id String @id @default(cuid()) + id String @id @default(uuid()) @db.Uuid sessionToken String @unique - userId String + userId String @db.Uuid expires DateTime user User @relation(fields: [userId], references: [id], onDelete: Cascade) } model User { - id String @id @default(cuid()) - name String? - email String? @unique - emailVerified DateTime? - image String? - accounts Account[] - sessions Session[] + id String @id @default(uuid()) @db.Uuid + name String? + email String? @unique + emailVerified DateTime? + image String? + accounts Account[] + sessions Session[] + OrganizationUser OrganizationUser[] + Organization Organization[] } model VerificationToken { diff --git a/prisma/seed.ts b/prisma/seed.ts index a4fe415..156fd9f 100644 --- a/prisma/seed.ts +++ b/prisma/seed.ts @@ -2,40 +2,47 @@ import { prisma } from "~/server/db"; import dedent from "dedent"; import { generateNewCell } from "~/server/utils/generateNewCell"; -const experimentId = "11111111-1111-1111-1111-111111111111"; +const defaultId = "11111111-1111-1111-1111-111111111111"; + +await prisma.organization.deleteMany({ + where: { id: defaultId }, +}); +await prisma.organization.create({ + data: { id: defaultId }, +}); -// Delete the existing experiment await prisma.experiment.deleteMany({ where: { - id: experimentId, + id: defaultId, }, }); await prisma.experiment.create({ data: { - id: experimentId, + id: defaultId, label: "Country Capitals Example", + organizationId: defaultId, }, }); await prisma.scenarioVariantCell.deleteMany({ where: { promptVariant: { - experimentId, + experimentId: defaultId, }, }, }); await prisma.promptVariant.deleteMany({ where: { - experimentId, + experimentId: defaultId, }, }); await prisma.promptVariant.createMany({ data: [ { - experimentId, + experimentId: defaultId, label: "Prompt Variant 1", sortIndex: 0, model: "gpt-3.5-turbo-0613", @@ -52,7 +59,7 @@ await prisma.promptVariant.createMany({ }`, }, { - experimentId, + experimentId: defaultId, label: "Prompt Variant 2", sortIndex: 1, model: "gpt-3.5-turbo-0613", @@ -73,14 +80,14 @@ await prisma.promptVariant.createMany({ await prisma.templateVariable.deleteMany({ where: { - experimentId, + experimentId: defaultId, }, }); await prisma.templateVariable.createMany({ data: [ { - experimentId, + experimentId: defaultId, label: "country", }, ], @@ -88,28 +95,28 @@ await prisma.templateVariable.createMany({ await prisma.testScenario.deleteMany({ where: { - experimentId, + experimentId: defaultId, }, }); await prisma.testScenario.createMany({ data: [ { - experimentId, + experimentId: defaultId, sortIndex: 0, variableValues: { country: "Spain", }, }, { - experimentId, + experimentId: defaultId, sortIndex: 1, variableValues: { country: "USA", }, }, { - experimentId, + experimentId: defaultId, sortIndex: 2, variableValues: { country: "Chile", @@ -120,13 +127,13 @@ await prisma.testScenario.createMany({ const variants = await prisma.promptVariant.findMany({ where: { - experimentId, + experimentId: defaultId, }, }); const scenarios = await prisma.testScenario.findMany({ where: { - experimentId, + experimentId: defaultId, }, }); diff --git a/render.yaml b/render.yaml index 485ae2f..c0b18b8 100644 --- a/render.yaml +++ b/render.yaml @@ -12,7 +12,7 @@ services: dockerContext: . plan: standard domains: - - openpipe.ai + - app.openpipe.ai envVars: - key: NODE_ENV value: production diff --git a/src/components/OutputsTable/NewScenarioButton.tsx b/src/components/OutputsTable/NewScenarioButton.tsx index 10365a9..3f1ddb6 100644 --- a/src/components/OutputsTable/NewScenarioButton.tsx +++ b/src/components/OutputsTable/NewScenarioButton.tsx @@ -1,7 +1,7 @@ import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react"; import { BsPlus } from "react-icons/bs"; import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; // Extracted Button styling into reusable component const StyledButton = ({ children, onClick }: ButtonProps) => ( @@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => ( ); export default function NewScenarioButton() { + const { canModify } = useExperimentAccess(); + const experiment = useExperiment(); const mutation = api.scenarios.create.useMutation(); const utils = api.useContext(); @@ -38,6 +40,8 @@ export default function NewScenarioButton() { await utils.scenarios.list.invalidate(); }, [mutation]); + if (!canModify) return null; + return ( diff --git a/src/components/OutputsTable/NewVariantButton.tsx b/src/components/OutputsTable/NewVariantButton.tsx index e4d6b47..b71bcf7 100644 --- a/src/components/OutputsTable/NewVariantButton.tsx +++ b/src/components/OutputsTable/NewVariantButton.tsx @@ -1,7 +1,7 @@ -import { Button, Icon, Spinner } from "@chakra-ui/react"; +import { Box, Button, Icon, Spinner } from "@chakra-ui/react"; import { BsPlus } from "react-icons/bs"; import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; import { cellPadding, headerMinHeight } from "../constants"; export default function NewVariantButton() { @@ -17,6 +17,9 @@ export default function NewVariantButton() { await utils.promptVariants.list.invalidate(); }, [mutation]); + const { canModify } = useExperimentAccess(); + if (!canModify) return ; + return ( + {!refetchingOutput && canModify && ( + + + )} ); diff --git a/src/components/OutputsTable/ScenarioEditor.tsx b/src/components/OutputsTable/ScenarioEditor.tsx index 3b2db7a..8049973 100644 --- a/src/components/OutputsTable/ScenarioEditor.tsx +++ b/src/components/OutputsTable/ScenarioEditor.tsx @@ -2,7 +2,7 @@ import { type DragEvent } from "react"; import { api } from "~/utils/api"; import { isEqual } from "lodash-es"; import { type Scenario } from "./types"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; +import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; import { useState } from "react"; import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react"; @@ -19,6 +19,8 @@ export default function ScenarioEditor({ hovered: boolean; canHide: boolean; }) { + const { canModify } = useExperimentAccess(); + const savedValues = scenario.variableValues as Record; const utils = api.useContext(); const [isDragTarget, setIsDragTarget] = useState(false); @@ -74,6 +76,7 @@ export default function ScenarioEditor({ alignItems="flex-start" pr={cellPadding.x} py={cellPadding.y} + pl={canModify ? 0 : cellPadding.x} height="100%" draggable={!variableInputHovered} onDragStart={(e) => { @@ -93,35 +96,38 @@ export default function ScenarioEditor({ onDrop={onReorder} backgroundColor={isDragTarget ? "gray.100" : "transparent"} > - - {props.canHide && ( - <> - - {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */} - + + - - - - - - )} - + _hover={{ color: "gray.800", cursor: "pointer" }} + /> + + )} + + )} + {variableLabels.length === 0 ? ( {vars.data ? "No scenario variables configured" : "Loading..."} ) : ( @@ -155,6 +161,8 @@ export default function ScenarioEditor({ fontSize="sm" lineHeight={1.2} value={value} + isDisabled={!canModify} + _disabled={{ opacity: 1, cursor: "default" }} onChange={(e) => { setValues((prev) => ({ ...prev, [key]: e.target.value })); }} diff --git a/src/components/OutputsTable/ScenariosHeader.tsx b/src/components/OutputsTable/ScenariosHeader.tsx index b09e753..c9a79e7 100644 --- a/src/components/OutputsTable/ScenariosHeader.tsx +++ b/src/components/OutputsTable/ScenariosHeader.tsx @@ -1,6 +1,6 @@ import { Button, GridItem, HStack, Heading } from "@chakra-ui/react"; import { cellPadding } from "../constants"; -import { useElementDimensions } from "~/utils/hooks"; +import { useElementDimensions, useExperimentAccess } from "~/utils/hooks"; import { stickyHeaderStyle } from "./styles"; import { BsPencil } from "react-icons/bs"; import { useAppStore } from "~/state/store"; @@ -13,6 +13,7 @@ export const ScenariosHeader = ({ numScenarios: number; }) => { const openDrawer = useAppStore((s) => s.openDrawer); + const { canModify } = useExperimentAccess(); const [ref, dimensions] = useElementDimensions(); const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px"; @@ -33,16 +34,18 @@ export const ScenariosHeader = ({ Scenarios ({numScenarios}) - + {canModify && ( + + )} ); diff --git a/src/components/OutputsTable/VariantEditor.tsx b/src/components/OutputsTable/VariantEditor.tsx index d5102ba..af11845 100644 --- a/src/components/OutputsTable/VariantEditor.tsx +++ b/src/components/OutputsTable/VariantEditor.tsx @@ -1,11 +1,12 @@ import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react"; import { useRef, useEffect, useState, useCallback } from "react"; -import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; +import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; import { type PromptVariant } from "./types"; import { api } from "~/utils/api"; import { useAppStore } from "~/state/store"; export default function VariantEditor(props: { variant: PromptVariant }) { + const { canModify } = useExperimentAccess(); const monaco = useAppStore.use.sharedVariantEditor.monaco(); const editorRef = useRef["editor"]["create"]> | null>(null); const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); @@ -40,18 +41,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) { const model = editorRef.current.getModel(); if (!model) return; - const markers = monaco?.editor.getModelMarkers({ resource: model.uri }); - const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error); - - if (hasErrors) { - toast({ - title: "Invalid TypeScript", - description: "Please fix the TypeScript errors before saving.", - status: "error", - }); - return; - } - // Make sure the user defined the prompt with the string "prompt\w*=" somewhere const promptRegex = /prompt\s*=/; if (!promptRegex.test(currentFn)) { @@ -103,6 +92,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) { wordWrapBreakAfterCharacters: "", wordWrapBreakBeforeCharacters: "", quickSuggestions: true, + readOnly: !canModify, }); editorRef.current.onDidFocusEditorText(() => { @@ -130,6 +120,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) { /* eslint-disable-next-line react-hooks/exhaustive-deps */ }, [monaco, editorId]); + useEffect(() => { + if (!editorRef.current) return; + editorRef.current.updateOptions({ + readOnly: !canModify, + }); + }, [canModify]); + return (
diff --git a/src/components/OutputsTable/styles.ts b/src/components/OutputsTable/styles.ts index 0dc8dc8..fcdac8f 100644 --- a/src/components/OutputsTable/styles.ts +++ b/src/components/OutputsTable/styles.ts @@ -2,7 +2,7 @@ import { type SystemStyleObject } from "@chakra-ui/react"; export const stickyHeaderStyle: SystemStyleObject = { position: "sticky", - top: "-1px", + top: "0", backgroundColor: "#fff", zIndex: 1, }; diff --git a/src/components/PublicPlaygroundWarning.tsx b/src/components/PublicPlaygroundWarning.tsx deleted file mode 100644 index ef612e5..0000000 --- a/src/components/PublicPlaygroundWarning.tsx +++ /dev/null @@ -1,21 +0,0 @@ -import { Flex, Icon, Link, Text } from "@chakra-ui/react"; -import { BsExclamationTriangleFill } from "react-icons/bs"; -import { env } from "~/env.mjs"; - -export default function PublicPlaygroundWarning() { - if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null; - - return ( - - - - Warning: this is a public playground. Anyone can see, edit or delete your experiments. For - private use,{" "} - - run a local copy - - . - - - ); -} diff --git a/src/components/VariantHeader/VariantHeader.tsx b/src/components/VariantHeader/VariantHeader.tsx index 96ad8a9..3c83316 100644 --- a/src/components/VariantHeader/VariantHeader.tsx +++ b/src/components/VariantHeader/VariantHeader.tsx @@ -1,15 +1,16 @@ import { useState, type DragEvent } from "react"; import { type PromptVariant } from "../OutputsTable/types"; import { api } from "~/utils/api"; -import { useHandledAsyncCallback } from "~/utils/hooks"; -import { HStack, Icon, GridItem } from "@chakra-ui/react"; // Changed here import { RiDraggable } from "react-icons/ri"; +import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; +import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here import { cellPadding, headerMinHeight } from "../constants"; import AutoResizeTextArea from "../AutoResizeTextArea"; import { stickyHeaderStyle } from "../OutputsTable/styles"; import VariantHeaderMenuButton from "./VariantHeaderMenuButton"; export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) { + const { canModify } = useExperimentAccess(); const utils = api.useContext(); const [isDragTarget, setIsDragTarget] = useState(false); const [isInputHovered, setIsInputHovered] = useState(false); @@ -44,6 +45,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide: const [menuOpen, setMenuOpen] = useState(false); + if (!canModify) { + return ( + + + {props.variant.label} + + + ); + } + return ( { - const isActive = useRouter().pathname.startsWith(href); + const router = useRouter(); + const isActive = href && router.pathname.startsWith(href); return ( { + const user = useSession().data; + return ( @@ -59,26 +63,32 @@ const NavSidebar = () => { - + {user != null && ( + <> + + + )} + {user === null && ( + { + signIn("github").catch(console.error); + }} + /> + )} - - - : } + + - + p={2} + > + + ); @@ -108,16 +118,13 @@ export default function AppShell(props: { children: React.ReactNode; title?: str {props.title ? `${props.title} | OpenPipe` : "OpenPipe"} - - - diff --git a/src/components/nav/UserMenu.tsx b/src/components/nav/UserMenu.tsx new file mode 100644 index 0000000..37710ff --- /dev/null +++ b/src/components/nav/UserMenu.tsx @@ -0,0 +1,72 @@ +import { + HStack, + Icon, + Image, + VStack, + Text, + Popover, + PopoverTrigger, + PopoverContent, + Link, +} from "@chakra-ui/react"; +import { type Session } from "next-auth"; +import { signOut } from "next-auth/react"; +import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs"; + +export default function UserMenu({ user }: { user: Session }) { + const profileImage = user.user.image ? ( + profile picture + ) : ( + + ); + + return ( + <> + + + + {profileImage} + + + {user.user.name} + + + {user.user.email} + + + + + + + + {/* sign out */} + { + signOut().catch(console.error); + }} + px={4} + py={2} + spacing={4} + color="gray.500" + fontSize="sm" + > + + Sign out + + + + + + ); +} diff --git a/src/env.mjs b/src/env.mjs index cb57366..2032c08 100644 --- a/src/env.mjs +++ b/src/env.mjs @@ -15,6 +15,8 @@ export const env = createEnv({ .optional() .default("false") .transform((val) => val.toLowerCase() === "true"), + GITHUB_CLIENT_ID: z.string().min(1), + GITHUB_CLIENT_SECRET: z.string().min(1), }, /** @@ -24,11 +26,6 @@ export const env = createEnv({ */ client: { NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(), - NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z - .string() - .optional() - .default("false") - .transform((val) => val.toLowerCase() === "true"), NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"), }, @@ -42,8 +39,9 @@ export const env = createEnv({ OPENAI_API_KEY: process.env.OPENAI_API_KEY, RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS, NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY, - NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND, NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL, + GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID, + GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET, }, /** * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. diff --git a/src/pages/account/signin.tsx b/src/pages/account/signin.tsx new file mode 100644 index 0000000..74e00d4 --- /dev/null +++ b/src/pages/account/signin.tsx @@ -0,0 +1,23 @@ +import { signIn, useSession } from "next-auth/react"; +import { useRouter } from "next/router"; +import { useEffect } from "react"; +import AppShell from "~/components/nav/AppShell"; + +export default function SignIn() { + const session = useSession().data; + const router = useRouter(); + + useEffect(() => { + if (session) { + router.push("/experiments").catch(console.error); + } else if (session === null) { + signIn("github").catch(console.error); + } + }, [session, router]); + + return ( + +
+ + ); +} diff --git a/src/pages/experiments/[id].tsx b/src/pages/experiments/[id].tsx index 4d5e3a7..896e248 100644 --- a/src/pages/experiments/[id].tsx +++ b/src/pages/experiments/[id].tsx @@ -124,6 +124,8 @@ export default function Experiment() { ); } + const canModify = experiment.data?.access.canModify ?? false; + return ( @@ -143,37 +145,45 @@ export default function Experiment() { - setLabel(e.target.value)} - onBlur={onSaveLabel} - borderWidth={1} - borderColor="transparent" - fontSize={16} - px={0} - minW={{ base: 100, lg: 300 }} - flex={1} - _hover={{ borderColor: "gray.300" }} - _focus={{ borderColor: "blue.500", outline: "none" }} - /> + {canModify ? ( + setLabel(e.target.value)} + onBlur={onSaveLabel} + borderWidth={1} + borderColor="transparent" + fontSize={16} + px={0} + minW={{ base: 100, lg: 300 }} + flex={1} + _hover={{ borderColor: "gray.300" }} + _focus={{ borderColor: "blue.500", outline: "none" }} + /> + ) : ( + + {experiment.data?.label} + + )} - - - - + {canModify && ( + + + + + )} diff --git a/src/pages/experiments/index.tsx b/src/pages/experiments/index.tsx index 4af8344..cb2702b 100644 --- a/src/pages/experiments/index.tsx +++ b/src/pages/experiments/index.tsx @@ -6,18 +6,44 @@ import { Breadcrumb, BreadcrumbItem, Flex, + Center, + Text, + Link, } from "@chakra-ui/react"; import { RiFlaskLine } from "react-icons/ri"; import AppShell from "~/components/nav/AppShell"; import { api } from "~/utils/api"; import { NewExperimentButton } from "~/components/experiments/NewExperimentButton"; import { ExperimentCard } from "~/components/experiments/ExperimentCard"; +import { signIn, useSession } from "next-auth/react"; export default function ExperimentsPage() { const experiments = api.experiments.list.useQuery(); + const user = useSession().data; + + if (user === null) { + return ( + +
+ + { + signIn("github").catch(console.error); + }} + textDecor="underline" + > + Sign in + {" "} + to view or create new experiments! + +
+
+ ); + } + return ( - + diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts index 7ee0d12..f113c20 100644 --- a/src/server/api/routers/evaluations.router.ts +++ b/src/server/api/routers/evaluations.router.ts @@ -1,20 +1,25 @@ import { EvalType } from "@prisma/client"; import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { runAllEvals } from "~/server/utils/evaluations"; +import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const evaluationsRouter = createTRPCRouter({ - list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { - return await prisma.evaluation.findMany({ - where: { - experimentId: input.experimentId, - }, - orderBy: { createdAt: "asc" }, - }); - }), + list: publicProcedure + .input(z.object({ experimentId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewExperiment(input.experimentId, ctx); - create: publicProcedure + return await prisma.evaluation.findMany({ + where: { + experimentId: input.experimentId, + }, + orderBy: { createdAt: "asc" }, + }); + }), + + create: protectedProcedure .input( z.object({ experimentId: z.string(), @@ -23,7 +28,9 @@ export const evaluationsRouter = createTRPCRouter({ evalType: z.nativeEnum(EvalType), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + await requireCanModifyExperiment(input.experimentId, ctx); + await prisma.evaluation.create({ data: { experimentId: input.experimentId, @@ -38,7 +45,7 @@ export const evaluationsRouter = createTRPCRouter({ await runAllEvals(input.experimentId); }), - update: publicProcedure + update: protectedProcedure .input( z.object({ id: z.string(), @@ -49,7 +56,12 @@ export const evaluationsRouter = createTRPCRouter({ }), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.evaluation.findUniqueOrThrow({ + where: { id: input.id }, + }); + await requireCanModifyExperiment(experimentId, ctx); + const evaluation = await prisma.evaluation.update({ where: { id: input.id }, data: { @@ -69,9 +81,16 @@ export const evaluationsRouter = createTRPCRouter({ await runAllEvals(evaluation.experimentId); }), - delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { - await prisma.evaluation.delete({ - where: { id: input.id }, - }); - }), + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.evaluation.findUniqueOrThrow({ + where: { id: input.id }, + }); + await requireCanModifyExperiment(experimentId, ctx); + + await prisma.evaluation.delete({ + where: { id: input.id }, + }); + }), }); diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts index 5ed8c85..02931e8 100644 --- a/src/server/api/routers/experiments.router.ts +++ b/src/server/api/routers/experiments.router.ts @@ -1,12 +1,29 @@ import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import dedent from "dedent"; import { generateNewCell } from "~/server/utils/generateNewCell"; +import { + canModifyExperiment, + requireCanModifyExperiment, + requireCanViewExperiment, + requireNothing, +} from "~/utils/accessControl"; +import userOrg from "~/server/utils/userOrg"; export const experimentsRouter = createTRPCRouter({ - list: publicProcedure.query(async () => { + list: protectedProcedure.query(async ({ ctx }) => { + // Anyone can list experiments + requireNothing(ctx); + const experiments = await prisma.experiment.findMany({ + where: { + organization: { + OrganizationUser: { + some: { userId: ctx.session.user.id }, + }, + }, + }, orderBy: { sortIndex: "asc", }, @@ -40,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({ return experimentsWithCounts; }), - get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => { - return await prisma.experiment.findFirst({ - where: { - id: input.id, - }, + get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { + await requireCanViewExperiment(input.id, ctx); + const experiment = await prisma.experiment.findFirstOrThrow({ + where: { id: input.id }, }); + + const canModify = ctx.session?.user.id + ? await canModifyExperiment(experiment.id, ctx.session?.user.id) + : false; + + return { + ...experiment, + access: { + canView: true, + canModify, + }, + }; }), - create: publicProcedure.input(z.object({})).mutation(async () => { + create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { + // Anyone can create an experiment + requireNothing(ctx); + const maxSortIndex = ( await prisma.experiment.aggregate({ @@ -62,6 +93,7 @@ export const experimentsRouter = createTRPCRouter({ data: { sortIndex: maxSortIndex + 1, label: `Experiment ${maxSortIndex + 1}`, + organizationId: (await userOrg(ctx.session.user.id)).id, }, }); @@ -117,9 +149,10 @@ export const experimentsRouter = createTRPCRouter({ return exp; }), - update: publicProcedure + update: protectedProcedure .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) })) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + await requireCanModifyExperiment(input.id, ctx); return await prisma.experiment.update({ where: { id: input.id, @@ -130,11 +163,15 @@ export const experimentsRouter = createTRPCRouter({ }); }), - delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { - await prisma.experiment.delete({ - where: { - id: input.id, - }, - }); - }), + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + await requireCanModifyExperiment(input.id, ctx); + + await prisma.experiment.delete({ + where: { + id: input.id, + }, + }); + }), }); diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index ac9ead2..d08243a 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -1,6 +1,6 @@ import { isObject } from "lodash-es"; import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { generateNewCell } from "~/server/utils/generateNewCell"; import { OpenAIChatModel, type SupportedModel } from "~/server/types"; @@ -11,129 +11,140 @@ import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants"; import { type PromptVariant } from "@prisma/client"; import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn"; +import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const promptVariantsRouter = createTRPCRouter({ - list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { - return await prisma.promptVariant.findMany({ - where: { - experimentId: input.experimentId, - visible: true, - }, - orderBy: { sortIndex: "asc" }, - }); - }), + list: publicProcedure + .input(z.object({ experimentId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewExperiment(input.experimentId, ctx); - stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => { - const variant = await prisma.promptVariant.findUnique({ - where: { - id: input.variantId, - }, - }); + return await prisma.promptVariant.findMany({ + where: { + experimentId: input.experimentId, + visible: true, + }, + orderBy: { sortIndex: "asc" }, + }); + }), - if (!variant) { - throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); - } + stats: publicProcedure + .input(z.object({ variantId: z.string() })) + .query(async ({ input, ctx }) => { + const variant = await prisma.promptVariant.findUnique({ + where: { + id: input.variantId, + }, + }); - const outputEvals = await prisma.outputEvaluation.groupBy({ - by: ["evaluationId"], - _sum: { - result: true, - }, - _count: { - id: true, - }, - where: { - modelOutput: { - scenarioVariantCell: { - promptVariant: { - id: input.variantId, - visible: true, + if (!variant) { + throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); + } + + await requireCanViewExperiment(variant.experimentId, ctx); + + const outputEvals = await prisma.outputEvaluation.groupBy({ + by: ["evaluationId"], + _sum: { + result: true, + }, + _count: { + id: true, + }, + where: { + modelOutput: { + scenarioVariantCell: { + promptVariant: { + id: input.variantId, + visible: true, + }, + testScenario: { + visible: true, + }, }, + }, + }, + }); + + const evals = await prisma.evaluation.findMany({ + where: { + experimentId: variant.experimentId, + }, + }); + + const evalResults = evals.map((evalItem) => { + const evalResult = outputEvals.find( + (outputEval) => outputEval.evaluationId === evalItem.id, + ); + return { + id: evalItem.id, + label: evalItem.label, + passCount: evalResult?._sum?.result ?? 0, + totalCount: evalResult?._count?.id ?? 1, + }; + }); + + const scenarioCount = await prisma.testScenario.count({ + where: { + experimentId: variant.experimentId, + visible: true, + }, + }); + const outputCount = await prisma.scenarioVariantCell.count({ + where: { + promptVariantId: input.variantId, + testScenario: { visible: true }, + modelOutput: { + is: {}, + }, + }, + }); + + const overallTokens = await prisma.modelOutput.aggregate({ + where: { + scenarioVariantCell: { + promptVariantId: input.variantId, testScenario: { visible: true, }, }, }, - }, - }); - - const evals = await prisma.evaluation.findMany({ - where: { - experimentId: variant.experimentId, - }, - }); - - const evalResults = evals.map((evalItem) => { - const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id); - return { - id: evalItem.id, - label: evalItem.label, - passCount: evalResult?._sum?.result ?? 0, - totalCount: evalResult?._count?.id ?? 1, - }; - }); - - const scenarioCount = await prisma.testScenario.count({ - where: { - experimentId: variant.experimentId, - visible: true, - }, - }); - const outputCount = await prisma.scenarioVariantCell.count({ - where: { - promptVariantId: input.variantId, - testScenario: { visible: true }, - modelOutput: { - is: {}, + _sum: { + promptTokens: true, + completionTokens: true, }, - }, - }); + }); - const overallTokens = await prisma.modelOutput.aggregate({ - where: { - scenarioVariantCell: { + const promptTokens = overallTokens._sum?.promptTokens ?? 0; + const overallPromptCost = calculateTokenCost(variant.model, promptTokens); + const completionTokens = overallTokens._sum?.completionTokens ?? 0; + const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true); + + const overallCost = overallPromptCost + overallCompletionCost; + + const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({ + where: { promptVariantId: input.variantId, - testScenario: { - visible: true, + testScenario: { visible: true }, + // Check if is PENDING or IN_PROGRESS + retrievalStatus: { + in: ["PENDING", "IN_PROGRESS"], }, }, - }, - _sum: { - promptTokens: true, - completionTokens: true, - }, - }); + })); - const promptTokens = overallTokens._sum?.promptTokens ?? 0; - const overallPromptCost = calculateTokenCost(variant.model, promptTokens); - const completionTokens = overallTokens._sum?.completionTokens ?? 0; - const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true); + return { + evalResults, + promptTokens, + completionTokens, + overallCost, + scenarioCount, + outputCount, + awaitingRetrievals, + }; + }), - const overallCost = overallPromptCost + overallCompletionCost; - - const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({ - where: { - promptVariantId: input.variantId, - testScenario: { visible: true }, - // Check if is PENDING or IN_PROGRESS - retrievalStatus: { - in: ["PENDING", "IN_PROGRESS"], - }, - }, - })); - - return { - evalResults, - promptTokens, - completionTokens, - overallCost, - scenarioCount, - outputCount, - awaitingRetrievals, - }; - }), - - create: publicProcedure + create: protectedProcedure .input( z.object({ experimentId: z.string(), @@ -141,7 +152,9 @@ export const promptVariantsRouter = createTRPCRouter({ newModel: z.string().optional(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + await requireCanViewExperiment(input.experimentId, ctx); + let originalVariant: PromptVariant | null = null; if (input.variantId) { originalVariant = await prisma.promptVariant.findUnique({ @@ -217,7 +230,7 @@ export const promptVariantsRouter = createTRPCRouter({ return newVariant; }), - update: publicProcedure + update: protectedProcedure .input( z.object({ id: z.string(), @@ -226,7 +239,7 @@ export const promptVariantsRouter = createTRPCRouter({ }), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { const existing = await prisma.promptVariant.findUnique({ where: { id: input.id, @@ -237,6 +250,8 @@ export const promptVariantsRouter = createTRPCRouter({ throw new Error(`Prompt Variant with id ${input.id} does not exist`); } + await requireCanModifyExperiment(existing.experimentId, ctx); + const updatePromptVariantAction = prisma.promptVariant.update({ where: { id: input.id, @@ -252,13 +267,18 @@ export const promptVariantsRouter = createTRPCRouter({ return updatedPromptVariant; }), - hide: publicProcedure + hide: protectedProcedure .input( z.object({ id: z.string(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({ + where: { id: input.id }, + }); + await requireCanModifyExperiment(experimentId, ctx); + const updatedPromptVariant = await prisma.promptVariant.update({ where: { id: input.id }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, @@ -267,19 +287,20 @@ export const promptVariantsRouter = createTRPCRouter({ return updatedPromptVariant; }), - replaceVariant: publicProcedure + replaceVariant: protectedProcedure .input( z.object({ id: z.string(), constructFn: z.string(), }), ) - .mutation(async ({ input }) => { - const existing = await prisma.promptVariant.findUnique({ + .mutation(async ({ input, ctx }) => { + const existing = await prisma.promptVariant.findUniqueOrThrow({ where: { id: input.id, }, }); + await requireCanModifyExperiment(existing.experimentId, ctx); if (!existing) { throw new Error(`Prompt Variant with id ${input.id} does not exist`); @@ -347,14 +368,19 @@ export const promptVariantsRouter = createTRPCRouter({ return { status: "ok" } as const; }), - reorder: publicProcedure + reorder: protectedProcedure .input( z.object({ draggedId: z.string(), droppedId: z.string(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({ + where: { id: input.draggedId }, + }); + await requireCanModifyExperiment(experimentId, ctx); + await reorderPromptVariants(input.draggedId, input.droppedId); }), }); diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts index b07657e..fee3d7a 100644 --- a/src/server/api/routers/scenarioVariantCells.router.ts +++ b/src/server/api/routers/scenarioVariantCells.router.ts @@ -1,8 +1,9 @@ import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { generateNewCell } from "~/server/utils/generateNewCell"; import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask"; +import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const scenarioVariantCellsRouter = createTRPCRouter({ get: publicProcedure @@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ variantId: z.string(), }), ) - .query(async ({ input }) => { + .query(async ({ input, ctx }) => { + const { experimentId } = await prisma.testScenario.findUniqueOrThrow({ + where: { id: input.scenarioId }, + }); + await requireCanViewExperiment(experimentId, ctx); + return await prisma.scenarioVariantCell.findUnique({ where: { promptVariantId_testScenarioId: { @@ -35,14 +41,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }, }); }), - forceRefetch: publicProcedure + forceRefetch: protectedProcedure .input( z.object({ scenarioId: z.string(), variantId: z.string(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.testScenario.findUniqueOrThrow({ + where: { id: input.scenarioId }, + }); + + await requireCanModifyExperiment(experimentId, ctx); + const cell = await prisma.scenarioVariantCell.findUnique({ where: { promptVariantId_testScenarioId: { diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts index 0ddfb0b..91f1852 100644 --- a/src/server/api/routers/scenarios.router.ts +++ b/src/server/api/routers/scenarios.router.ts @@ -1,32 +1,39 @@ import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { autogenerateScenarioValues } from "../autogen"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { runAllEvals } from "~/server/utils/evaluations"; import { generateNewCell } from "~/server/utils/generateNewCell"; +import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const scenariosRouter = createTRPCRouter({ - list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { - return await prisma.testScenario.findMany({ - where: { - experimentId: input.experimentId, - visible: true, - }, - orderBy: { - sortIndex: "asc", - }, - }); - }), + list: publicProcedure + .input(z.object({ experimentId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewExperiment(input.experimentId, ctx); - create: publicProcedure + return await prisma.testScenario.findMany({ + where: { + experimentId: input.experimentId, + visible: true, + }, + orderBy: { + sortIndex: "asc", + }, + }); + }), + + create: protectedProcedure .input( z.object({ experimentId: z.string(), autogenerate: z.boolean().optional(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + await requireCanModifyExperiment(input.experimentId, ctx); + const maxSortIndex = ( await prisma.testScenario.aggregate({ @@ -66,7 +73,14 @@ export const scenariosRouter = createTRPCRouter({ } }), - hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { + hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => { + const experimentId = ( + await prisma.testScenario.findUniqueOrThrow({ + where: { id: input.id }, + }) + ).experimentId; + + await requireCanModifyExperiment(experimentId, ctx); const hiddenScenario = await prisma.testScenario.update({ where: { id: input.id }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, @@ -78,14 +92,14 @@ export const scenariosRouter = createTRPCRouter({ return hiddenScenario; }), - reorder: publicProcedure + reorder: protectedProcedure .input( z.object({ draggedId: z.string(), droppedId: z.string(), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { const dragged = await prisma.testScenario.findUnique({ where: { id: input.draggedId, @@ -104,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({ ); } + await requireCanModifyExperiment(dragged.experimentId, ctx); + const visibleItems = await prisma.testScenario.findMany({ where: { experimentId: dragged.experimentId, @@ -147,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({ ); }), - replaceWithValues: publicProcedure + replaceWithValues: protectedProcedure .input( z.object({ id: z.string(), values: z.record(z.string()), }), ) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { const existing = await prisma.testScenario.findUnique({ where: { id: input.id, @@ -165,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({ throw new Error(`Scenario with id ${input.id} does not exist`); } + await requireCanModifyExperiment(existing.experimentId, ctx); + const newScenario = await prisma.testScenario.create({ data: { experimentId: existing.experimentId, diff --git a/src/server/api/routers/templateVariables.router.ts b/src/server/api/routers/templateVariables.router.ts index 6762112..d62fec4 100644 --- a/src/server/api/routers/templateVariables.router.ts +++ b/src/server/api/routers/templateVariables.router.ts @@ -1,11 +1,14 @@ import { z } from "zod"; -import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; +import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; +import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl"; export const templateVarsRouter = createTRPCRouter({ - create: publicProcedure + create: protectedProcedure .input(z.object({ experimentId: z.string(), label: z.string() })) - .mutation(async ({ input }) => { + .mutation(async ({ input, ctx }) => { + await requireCanModifyExperiment(input.experimentId, ctx); + await prisma.templateVariable.create({ data: { experimentId: input.experimentId, @@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({ }); }), - delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { - await prisma.templateVariable.delete({ where: { id: input.id } }); - }), + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({ + where: { id: input.id }, + }); - list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { - return await prisma.templateVariable.findMany({ - where: { - experimentId: input.experimentId, - }, - orderBy: { - createdAt: "asc", - }, - select: { - id: true, - label: true, - }, - }); - }), + await requireCanModifyExperiment(experimentId, ctx); + + await prisma.templateVariable.delete({ where: { id: input.id } }); + }), + + list: publicProcedure + .input(z.object({ experimentId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewExperiment(input.experimentId, ctx); + return await prisma.templateVariable.findMany({ + where: { + experimentId: input.experimentId, + }, + orderBy: { + createdAt: "asc", + }, + select: { + id: true, + label: true, + }, + }); + }), }); diff --git a/src/server/api/trpc.ts b/src/server/api/trpc.ts index 244cabd..e795857 100644 --- a/src/server/api/trpc.ts +++ b/src/server/api/trpc.ts @@ -27,6 +27,9 @@ type CreateContextOptions = { session: Session | null; }; +// eslint-disable-next-line @typescript-eslint/no-empty-function +const noOp = () => {}; + /** * This helper generates the "internals" for a tRPC context. If you need to use it, you can export * it from here. @@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => { return { session: opts.session, prisma, + markAccessControlRun: noOp, }; }; @@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => { * errors on the backend. */ +export type TRPCContext = Awaited>; + const t = initTRPC.context().create({ transformer: superjson, errorFormatter({ shape, error }) { @@ -106,16 +112,29 @@ export const createTRPCRouter = t.router; export const publicProcedure = t.procedure; /** Reusable middleware that enforces users are logged in before running the procedure. */ -const enforceUserIsAuthed = t.middleware(({ ctx, next }) => { +const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => { if (!ctx.session || !ctx.session.user) { throw new TRPCError({ code: "UNAUTHORIZED" }); } - return next({ + + let accessControlRun = false; + const resp = await next({ ctx: { // infers the `session` as non-nullable session: { ...ctx.session, user: ctx.session.user }, + markAccessControlRun: () => { + accessControlRun = true; + }, }, }); + if (!accessControlRun) + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: + "Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.", + }); + + return resp; }); /** diff --git a/src/server/auth.ts b/src/server/auth.ts index 4531237..f2b779c 100644 --- a/src/server/auth.ts +++ b/src/server/auth.ts @@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter"; import { type GetServerSidePropsContext } from "next"; import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth"; import { prisma } from "~/server/db"; +import GitHubProvider from "next-auth/providers/github"; +import { env } from "~/env.mjs"; /** * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session` @@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = { }, adapter: PrismaAdapter(prisma), providers: [ - // DiscordProvider({ - // clientId: env.DISCORD_CLIENT_ID, - // clientSecret: env.DISCORD_CLIENT_SECRET, - // }), - /** - * ...add more providers here. - * - * Most other providers require a bit more work than the Discord provider. For example, the - * GitHub provider requires you to add the `refresh_token_expires_in` field to the Account - * model. Refer to the NextAuth.js docs for the provider you want to use. Example: - * - * @see https://next-auth.js.org/providers/github - */ + GitHubProvider({ + clientId: env.GITHUB_CLIENT_ID, + clientSecret: env.GITHUB_CLIENT_SECRET, + }), ], + theme: { + logo: "/logo.svg", + brandColor: "#ff5733", + }, }; /** diff --git a/src/server/utils/userOrg.ts b/src/server/utils/userOrg.ts new file mode 100644 index 0000000..4158f87 --- /dev/null +++ b/src/server/utils/userOrg.ts @@ -0,0 +1,19 @@ +import { prisma } from "~/server/db"; + +export default async function userOrg(userId: string) { + return await prisma.organization.upsert({ + where: { + personalOrgUserId: userId, + }, + update: {}, + create: { + personalOrgUserId: userId, + OrganizationUser: { + create: { + userId: userId, + role: "ADMIN", + }, + }, + }, + }); +} diff --git a/src/utils/accessControl.ts b/src/utils/accessControl.ts new file mode 100644 index 0000000..e43d57d --- /dev/null +++ b/src/utils/accessControl.ts @@ -0,0 +1,49 @@ +import { OrganizationUserRole } from "@prisma/client"; +import { TRPCError } from "@trpc/server"; +import { type TRPCContext } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; + +// No-op method for protected routes that really should be accessible to anyone. +export const requireNothing = (ctx: TRPCContext) => { + ctx.markAccessControlRun(); +}; + +export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => { + await prisma.experiment.findFirst({ + where: { id: experimentId }, + }); + + // Right now all experiments are publicly viewable, so this is a no-op. + ctx.markAccessControlRun(); +}; + +export const canModifyExperiment = async (experimentId: string, userId: string) => { + const experiment = await prisma.experiment.findFirst({ + where: { + id: experimentId, + organization: { + OrganizationUser: { + some: { + role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, + userId, + }, + }, + }, + }, + }); + + return !!experiment; +}; + +export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => { + const userId = ctx.session?.user.id; + if (!userId) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + if (!(await canModifyExperiment(experimentId, userId))) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + ctx.markAccessControlRun(); +}; diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts index 7bb1c4b..fed6637 100644 --- a/src/utils/hooks.ts +++ b/src/utils/hooks.ts @@ -12,6 +12,10 @@ export const useExperiment = () => { return experiment; }; +export const useExperimentAccess = () => { + return useExperiment().data?.access ?? { canView: false, canModify: false }; +}; + type AsyncFunction = (...args: T) => Promise; export function useHandledAsyncCallback(