diff --git a/.env.example b/.env.example
index 1d6e9d7..d0ecc56 100644
--- a/.env.example
+++ b/.env.example
@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
OPENAI_API_KEY=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
+
+# Next Auth
+NEXTAUTH_SECRET="your_secret"
+NEXTAUTH_URL="http://localhost:3000"
+
+# Next Auth Github Provider
+GITHUB_CLIENT_ID="your_client_id"
+GITHUB_CLIENT_SECRET="your_secret"
diff --git a/@types/nextjs-routes.d.ts b/@types/nextjs-routes.d.ts
index 08e1564..d16d76e 100644
--- a/@types/nextjs-routes.d.ts
+++ b/@types/nextjs-routes.d.ts
@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
} from "next";
export type Route =
+ | StaticRoute<"/account/signin">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
| DynamicRoute<"/experiments/[id]", { "id": string }>
diff --git a/README.md b/README.md
index 995a484..bdcbadb 100644
--- a/README.md
+++ b/README.md
@@ -4,11 +4,18 @@
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
-**Live Demo:** https://openpipe.ai
+## Sample Experiments
+
+These are simple experiments users have created that show how OpenPipe works.
+
+- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
+- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
+- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
+- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
-Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
+You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
## High-Level Features
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
5. Install the dependencies: `cd openpipe && pnpm install`
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
-8. Start the app: `pnpm dev`.
-9. Navigate to [http://localhost:3000](http://localhost:3000)
+8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
+9. Start the app: `pnpm dev`.
+10. Navigate to [http://localhost:3000](http://localhost:3000)
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 3472e54..b098a06 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1,4 +1,4 @@
-lockfileVersion: '6.0'
+lockfileVersion: '6.1'
settings:
autoInstallPeers: true
diff --git a/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql b/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql
new file mode 100644
index 0000000..bbd3152
--- /dev/null
+++ b/prisma/migrations/20230718201303_add_users_and_orgs/migration.sql
@@ -0,0 +1,124 @@
+DROP TABLE "Account";
+DROP TABLE "Session";
+DROP TABLE "User";
+DROP TABLE "VerificationToken";
+
+CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
+
+-- CreateTable
+CREATE TABLE "Organization" (
+ "id" UUID NOT NULL,
+ "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ "updatedAt" TIMESTAMP(3) NOT NULL,
+ "personalOrgUserId" UUID,
+
+ CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateTable
+CREATE TABLE "OrganizationUser" (
+ "id" UUID NOT NULL,
+ "role" "OrganizationUserRole" NOT NULL,
+ "organizationId" UUID NOT NULL,
+ "userId" UUID NOT NULL,
+ "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ "updatedAt" TIMESTAMP(3) NOT NULL,
+
+ CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateTable
+CREATE TABLE "Account" (
+ "id" UUID NOT NULL,
+ "userId" UUID NOT NULL,
+ "type" TEXT NOT NULL,
+ "provider" TEXT NOT NULL,
+ "providerAccountId" TEXT NOT NULL,
+ "refresh_token" TEXT,
+ "refresh_token_expires_in" INTEGER,
+ "access_token" TEXT,
+ "expires_at" INTEGER,
+ "token_type" TEXT,
+ "scope" TEXT,
+ "id_token" TEXT,
+ "session_state" TEXT,
+
+ CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateTable
+CREATE TABLE "Session" (
+ "id" UUID NOT NULL,
+ "sessionToken" TEXT NOT NULL,
+ "userId" UUID NOT NULL,
+ "expires" TIMESTAMP(3) NOT NULL,
+
+ CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateTable
+CREATE TABLE "User" (
+ "id" UUID NOT NULL,
+ "name" TEXT,
+ "email" TEXT,
+ "emailVerified" TIMESTAMP(3),
+ "image" TEXT,
+
+ CONSTRAINT "User_pkey" PRIMARY KEY ("id")
+);
+
+-- CreateTable
+CREATE TABLE "VerificationToken" (
+ "identifier" TEXT NOT NULL,
+ "token" TEXT NOT NULL,
+ "expires" TIMESTAMP(3) NOT NULL
+);
+
+INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
+
+-- AlterTable add organizationId as NULLABLE
+ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
+
+-- Set default organization for existing experiments
+UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
+
+-- AlterTable set organizationId as NOT NULL
+ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
+
+
+-- CreateIndex
+CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
+
+-- CreateIndex
+CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
+
+-- CreateIndex
+CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
+
+-- CreateIndex
+CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
+
+-- CreateIndex
+CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
+
+-- CreateIndex
+CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
+
+-- AddForeignKey
+ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
+
+CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
+
+ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index e8f8ea7..1afa8fc 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -16,8 +16,12 @@ model Experiment {
sortIndex Int @default(0)
- createdAt DateTime @default(now())
- updatedAt DateTime @updatedAt
+ organizationId String @db.Uuid
+ organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+
+ createdAt DateTime @default(now())
+ updatedAt DateTime @updatedAt
+
TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[]
TestScenario TestScenario[]
@@ -169,41 +173,77 @@ model OutputEvaluation {
@@unique([modelOutputId, evaluationId])
}
-// Necessary for Next auth
+model Organization {
+ id String @id @default(uuid()) @db.Uuid
+ personalOrgUserId String? @unique @db.Uuid
+ PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
+
+ createdAt DateTime @default(now())
+ updatedAt DateTime @updatedAt
+ OrganizationUser OrganizationUser[]
+ Experiment Experiment[]
+}
+
+enum OrganizationUserRole {
+ ADMIN
+ MEMBER
+ VIEWER
+}
+
+model OrganizationUser {
+ id String @id @default(uuid()) @db.Uuid
+
+ role OrganizationUserRole
+
+ organizationId String @db.Uuid
+ organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+
+ userId String @db.Uuid
+ user User @relation(fields: [userId], references: [id], onDelete: Cascade)
+
+ createdAt DateTime @default(now())
+ updatedAt DateTime @updatedAt
+
+ @@unique([organizationId, userId])
+}
+
model Account {
- id String @id @default(cuid())
- userId String
- type String
- provider String
- providerAccountId String
- refresh_token String? // @db.Text
- access_token String? // @db.Text
- expires_at Int?
- token_type String?
- scope String?
- id_token String? // @db.Text
- session_state String?
- user User @relation(fields: [userId], references: [id], onDelete: Cascade)
+ id String @id @default(uuid()) @db.Uuid
+ userId String @db.Uuid
+ type String
+ provider String
+ providerAccountId String
+ refresh_token String? @db.Text
+ refresh_token_expires_in Int?
+ access_token String? @db.Text
+ expires_at Int?
+ token_type String?
+ scope String?
+ id_token String? @db.Text
+ session_state String?
+ user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([provider, providerAccountId])
}
model Session {
- id String @id @default(cuid())
+ id String @id @default(uuid()) @db.Uuid
sessionToken String @unique
- userId String
+ userId String @db.Uuid
expires DateTime
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
model User {
- id String @id @default(cuid())
- name String?
- email String? @unique
- emailVerified DateTime?
- image String?
- accounts Account[]
- sessions Session[]
+ id String @id @default(uuid()) @db.Uuid
+ name String?
+ email String? @unique
+ emailVerified DateTime?
+ image String?
+ accounts Account[]
+ sessions Session[]
+ OrganizationUser OrganizationUser[]
+ Organization Organization[]
}
model VerificationToken {
diff --git a/prisma/seed.ts b/prisma/seed.ts
index a4fe415..156fd9f 100644
--- a/prisma/seed.ts
+++ b/prisma/seed.ts
@@ -2,40 +2,47 @@ import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
-const experimentId = "11111111-1111-1111-1111-111111111111";
+const defaultId = "11111111-1111-1111-1111-111111111111";
+
+await prisma.organization.deleteMany({
+ where: { id: defaultId },
+});
+await prisma.organization.create({
+ data: { id: defaultId },
+});
-// Delete the existing experiment
await prisma.experiment.deleteMany({
where: {
- id: experimentId,
+ id: defaultId,
},
});
await prisma.experiment.create({
data: {
- id: experimentId,
+ id: defaultId,
label: "Country Capitals Example",
+ organizationId: defaultId,
},
});
await prisma.scenarioVariantCell.deleteMany({
where: {
promptVariant: {
- experimentId,
+ experimentId: defaultId,
},
},
});
await prisma.promptVariant.deleteMany({
where: {
- experimentId,
+ experimentId: defaultId,
},
});
await prisma.promptVariant.createMany({
data: [
{
- experimentId,
+ experimentId: defaultId,
label: "Prompt Variant 1",
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
@@ -52,7 +59,7 @@ await prisma.promptVariant.createMany({
}`,
},
{
- experimentId,
+ experimentId: defaultId,
label: "Prompt Variant 2",
sortIndex: 1,
model: "gpt-3.5-turbo-0613",
@@ -73,14 +80,14 @@ await prisma.promptVariant.createMany({
await prisma.templateVariable.deleteMany({
where: {
- experimentId,
+ experimentId: defaultId,
},
});
await prisma.templateVariable.createMany({
data: [
{
- experimentId,
+ experimentId: defaultId,
label: "country",
},
],
@@ -88,28 +95,28 @@ await prisma.templateVariable.createMany({
await prisma.testScenario.deleteMany({
where: {
- experimentId,
+ experimentId: defaultId,
},
});
await prisma.testScenario.createMany({
data: [
{
- experimentId,
+ experimentId: defaultId,
sortIndex: 0,
variableValues: {
country: "Spain",
},
},
{
- experimentId,
+ experimentId: defaultId,
sortIndex: 1,
variableValues: {
country: "USA",
},
},
{
- experimentId,
+ experimentId: defaultId,
sortIndex: 2,
variableValues: {
country: "Chile",
@@ -120,13 +127,13 @@ await prisma.testScenario.createMany({
const variants = await prisma.promptVariant.findMany({
where: {
- experimentId,
+ experimentId: defaultId,
},
});
const scenarios = await prisma.testScenario.findMany({
where: {
- experimentId,
+ experimentId: defaultId,
},
});
diff --git a/render.yaml b/render.yaml
index 485ae2f..c0b18b8 100644
--- a/render.yaml
+++ b/render.yaml
@@ -12,7 +12,7 @@ services:
dockerContext: .
plan: standard
domains:
- - openpipe.ai
+ - app.openpipe.ai
envVars:
- key: NODE_ENV
value: production
diff --git a/src/components/OutputsTable/NewScenarioButton.tsx b/src/components/OutputsTable/NewScenarioButton.tsx
index 10365a9..3f1ddb6 100644
--- a/src/components/OutputsTable/NewScenarioButton.tsx
+++ b/src/components/OutputsTable/NewScenarioButton.tsx
@@ -1,7 +1,7 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
);
export default function NewScenarioButton() {
+ const { canModify } = useExperimentAccess();
+
const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation();
const utils = api.useContext();
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
await utils.scenarios.list.invalidate();
}, [mutation]);
+ if (!canModify) return null;
+
return (
diff --git a/src/components/OutputsTable/NewVariantButton.tsx b/src/components/OutputsTable/NewVariantButton.tsx
index e4d6b47..b71bcf7 100644
--- a/src/components/OutputsTable/NewVariantButton.tsx
+++ b/src/components/OutputsTable/NewVariantButton.tsx
@@ -1,7 +1,7 @@
-import { Button, Icon, Spinner } from "@chakra-ui/react";
+import { Box, Button, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
-import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
+import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() {
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
await utils.promptVariants.list.invalidate();
}, [mutation]);
+ const { canModify } = useExperimentAccess();
+ if (!canModify) return ;
+
return (
+ {canModify && (
+ }
+ onClick={openDrawer}
+ >
+ Edit Vars
+
+ )}
);
diff --git a/src/components/OutputsTable/VariantEditor.tsx b/src/components/OutputsTable/VariantEditor.tsx
index d5102ba..af11845 100644
--- a/src/components/OutputsTable/VariantEditor.tsx
+++ b/src/components/OutputsTable/VariantEditor.tsx
@@ -1,11 +1,12 @@
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
-import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
+import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
export default function VariantEditor(props: { variant: PromptVariant }) {
+ const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -40,18 +41,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const model = editorRef.current.getModel();
if (!model) return;
- const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
- const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
-
- if (hasErrors) {
- toast({
- title: "Invalid TypeScript",
- description: "Please fix the TypeScript errors before saving.",
- status: "error",
- });
- return;
- }
-
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
const promptRegex = /prompt\s*=/;
if (!promptRegex.test(currentFn)) {
@@ -103,6 +92,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
quickSuggestions: true,
+ readOnly: !canModify,
});
editorRef.current.onDidFocusEditorText(() => {
@@ -130,6 +120,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
/* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]);
+ useEffect(() => {
+ if (!editorRef.current) return;
+ editorRef.current.updateOptions({
+ readOnly: !canModify,
+ });
+ }, [canModify]);
+
return (
diff --git a/src/components/OutputsTable/styles.ts b/src/components/OutputsTable/styles.ts
index 0dc8dc8..fcdac8f 100644
--- a/src/components/OutputsTable/styles.ts
+++ b/src/components/OutputsTable/styles.ts
@@ -2,7 +2,7 @@ import { type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
- top: "-1px",
+ top: "0",
backgroundColor: "#fff",
zIndex: 1,
};
diff --git a/src/components/PublicPlaygroundWarning.tsx b/src/components/PublicPlaygroundWarning.tsx
deleted file mode 100644
index ef612e5..0000000
--- a/src/components/PublicPlaygroundWarning.tsx
+++ /dev/null
@@ -1,21 +0,0 @@
-import { Flex, Icon, Link, Text } from "@chakra-ui/react";
-import { BsExclamationTriangleFill } from "react-icons/bs";
-import { env } from "~/env.mjs";
-
-export default function PublicPlaygroundWarning() {
- if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
-
- return (
-
-
-
- Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
- private use,{" "}
-
- run a local copy
-
- .
-
-
- );
-}
diff --git a/src/components/VariantHeader/VariantHeader.tsx b/src/components/VariantHeader/VariantHeader.tsx
index 96ad8a9..3c83316 100644
--- a/src/components/VariantHeader/VariantHeader.tsx
+++ b/src/components/VariantHeader/VariantHeader.tsx
@@ -1,15 +1,16 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
-import { useHandledAsyncCallback } from "~/utils/hooks";
-import { HStack, Icon, GridItem } from "@chakra-ui/react"; // Changed here
import { RiDraggable } from "react-icons/ri";
+import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
+import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
+ const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
@@ -44,6 +45,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
const [menuOpen, setMenuOpen] = useState(false);
+ if (!canModify) {
+ return (
+
+
+ {props.variant.label}
+
+
+ );
+ }
+
return (
{
- const isActive = useRouter().pathname.startsWith(href);
+ const router = useRouter();
+ const isActive = href && router.pathname.startsWith(href);
return (
{
+ const user = useSession().data;
+
return (
@@ -59,26 +63,32 @@ const NavSidebar = () => {
-
+ {user != null && (
+ <>
+
+ >
+ )}
+ {user === null && (
+ {
+ signIn("github").catch(console.error);
+ }}
+ />
+ )}
-
-
- : }
+
+
-
+ p={2}
+ >
+
+
);
@@ -108,16 +118,13 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}
-
-
-
diff --git a/src/components/nav/UserMenu.tsx b/src/components/nav/UserMenu.tsx
new file mode 100644
index 0000000..37710ff
--- /dev/null
+++ b/src/components/nav/UserMenu.tsx
@@ -0,0 +1,72 @@
+import {
+ HStack,
+ Icon,
+ Image,
+ VStack,
+ Text,
+ Popover,
+ PopoverTrigger,
+ PopoverContent,
+ Link,
+} from "@chakra-ui/react";
+import { type Session } from "next-auth";
+import { signOut } from "next-auth/react";
+import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
+
+export default function UserMenu({ user }: { user: Session }) {
+ const profileImage = user.user.image ? (
+
+ ) : (
+
+ );
+
+ return (
+ <>
+
+
+
+ {profileImage}
+
+
+ {user.user.name}
+
+
+ {user.user.email}
+
+
+
+
+
+
+
+ {/* sign out */}
+ {
+ signOut().catch(console.error);
+ }}
+ px={4}
+ py={2}
+ spacing={4}
+ color="gray.500"
+ fontSize="sm"
+ >
+
+ Sign out
+
+
+
+
+ >
+ );
+}
diff --git a/src/env.mjs b/src/env.mjs
index cb57366..2032c08 100644
--- a/src/env.mjs
+++ b/src/env.mjs
@@ -15,6 +15,8 @@ export const env = createEnv({
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
+ GITHUB_CLIENT_ID: z.string().min(1),
+ GITHUB_CLIENT_SECRET: z.string().min(1),
},
/**
@@ -24,11 +26,6 @@ export const env = createEnv({
*/
client: {
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
- NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
- .string()
- .optional()
- .default("false")
- .transform((val) => val.toLowerCase() === "true"),
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
},
@@ -42,8 +39,9 @@ export const env = createEnv({
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
- NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
+ GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
+ GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
},
/**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
diff --git a/src/pages/account/signin.tsx b/src/pages/account/signin.tsx
new file mode 100644
index 0000000..74e00d4
--- /dev/null
+++ b/src/pages/account/signin.tsx
@@ -0,0 +1,23 @@
+import { signIn, useSession } from "next-auth/react";
+import { useRouter } from "next/router";
+import { useEffect } from "react";
+import AppShell from "~/components/nav/AppShell";
+
+export default function SignIn() {
+ const session = useSession().data;
+ const router = useRouter();
+
+ useEffect(() => {
+ if (session) {
+ router.push("/experiments").catch(console.error);
+ } else if (session === null) {
+ signIn("github").catch(console.error);
+ }
+ }, [session, router]);
+
+ return (
+
+
+
+ );
+}
diff --git a/src/pages/experiments/[id].tsx b/src/pages/experiments/[id].tsx
index 4d5e3a7..896e248 100644
--- a/src/pages/experiments/[id].tsx
+++ b/src/pages/experiments/[id].tsx
@@ -124,6 +124,8 @@ export default function Experiment() {
);
}
+ const canModify = experiment.data?.access.canModify ?? false;
+
return (
@@ -143,37 +145,45 @@ export default function Experiment() {
- setLabel(e.target.value)}
- onBlur={onSaveLabel}
- borderWidth={1}
- borderColor="transparent"
- fontSize={16}
- px={0}
- minW={{ base: 100, lg: 300 }}
- flex={1}
- _hover={{ borderColor: "gray.300" }}
- _focus={{ borderColor: "blue.500", outline: "none" }}
- />
+ {canModify ? (
+ setLabel(e.target.value)}
+ onBlur={onSaveLabel}
+ borderWidth={1}
+ borderColor="transparent"
+ fontSize={16}
+ px={0}
+ minW={{ base: 100, lg: 300 }}
+ flex={1}
+ _hover={{ borderColor: "gray.300" }}
+ _focus={{ borderColor: "blue.500", outline: "none" }}
+ />
+ ) : (
+
+ {experiment.data?.label}
+
+ )}
-
-
-
-
- Edit Vars & Evals
-
-
-
-
+ {canModify && (
+
+
+
+
+ Edit Vars & Evals
+
+
+
+
+ )}
diff --git a/src/pages/experiments/index.tsx b/src/pages/experiments/index.tsx
index 4af8344..cb2702b 100644
--- a/src/pages/experiments/index.tsx
+++ b/src/pages/experiments/index.tsx
@@ -6,18 +6,44 @@ import {
Breadcrumb,
BreadcrumbItem,
Flex,
+ Center,
+ Text,
+ Link,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
import { ExperimentCard } from "~/components/experiments/ExperimentCard";
+import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery();
+ const user = useSession().data;
+
+ if (user === null) {
+ return (
+
+
+
+ {
+ signIn("github").catch(console.error);
+ }}
+ textDecor="underline"
+ >
+ Sign in
+ {" "}
+ to view or create new experiments!
+
+
+
+ );
+ }
+
return (
-
+
diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts
index 7ee0d12..f113c20 100644
--- a/src/server/api/routers/evaluations.router.ts
+++ b/src/server/api/routers/evaluations.router.ts
@@ -1,20 +1,25 @@
import { EvalType } from "@prisma/client";
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { runAllEvals } from "~/server/utils/evaluations";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const evaluationsRouter = createTRPCRouter({
- list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
- return await prisma.evaluation.findMany({
- where: {
- experimentId: input.experimentId,
- },
- orderBy: { createdAt: "asc" },
- });
- }),
+ list: publicProcedure
+ .input(z.object({ experimentId: z.string() }))
+ .query(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.experimentId, ctx);
- create: publicProcedure
+ return await prisma.evaluation.findMany({
+ where: {
+ experimentId: input.experimentId,
+ },
+ orderBy: { createdAt: "asc" },
+ });
+ }),
+
+ create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
@@ -23,7 +28,9 @@ export const evaluationsRouter = createTRPCRouter({
evalType: z.nativeEnum(EvalType),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ await requireCanModifyExperiment(input.experimentId, ctx);
+
await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
@@ -38,7 +45,7 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(input.experimentId);
}),
- update: publicProcedure
+ update: protectedProcedure
.input(
z.object({
id: z.string(),
@@ -49,7 +56,12 @@ export const evaluationsRouter = createTRPCRouter({
}),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+ where: { id: input.id },
+ });
+ await requireCanModifyExperiment(experimentId, ctx);
+
const evaluation = await prisma.evaluation.update({
where: { id: input.id },
data: {
@@ -69,9 +81,16 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(evaluation.experimentId);
}),
- delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
- await prisma.evaluation.delete({
- where: { id: input.id },
- });
- }),
+ delete: protectedProcedure
+ .input(z.object({ id: z.string() }))
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
+ where: { id: input.id },
+ });
+ await requireCanModifyExperiment(experimentId, ctx);
+
+ await prisma.evaluation.delete({
+ where: { id: input.id },
+ });
+ }),
});
diff --git a/src/server/api/routers/experiments.router.ts b/src/server/api/routers/experiments.router.ts
index 5ed8c85..02931e8 100644
--- a/src/server/api/routers/experiments.router.ts
+++ b/src/server/api/routers/experiments.router.ts
@@ -1,12 +1,29 @@
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
+import {
+ canModifyExperiment,
+ requireCanModifyExperiment,
+ requireCanViewExperiment,
+ requireNothing,
+} from "~/utils/accessControl";
+import userOrg from "~/server/utils/userOrg";
export const experimentsRouter = createTRPCRouter({
- list: publicProcedure.query(async () => {
+ list: protectedProcedure.query(async ({ ctx }) => {
+ // Anyone can list experiments
+ requireNothing(ctx);
+
const experiments = await prisma.experiment.findMany({
+ where: {
+ organization: {
+ OrganizationUser: {
+ some: { userId: ctx.session.user.id },
+ },
+ },
+ },
orderBy: {
sortIndex: "asc",
},
@@ -40,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({
return experimentsWithCounts;
}),
- get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
- return await prisma.experiment.findFirst({
- where: {
- id: input.id,
- },
+ get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.id, ctx);
+ const experiment = await prisma.experiment.findFirstOrThrow({
+ where: { id: input.id },
});
+
+ const canModify = ctx.session?.user.id
+ ? await canModifyExperiment(experiment.id, ctx.session?.user.id)
+ : false;
+
+ return {
+ ...experiment,
+ access: {
+ canView: true,
+ canModify,
+ },
+ };
}),
- create: publicProcedure.input(z.object({})).mutation(async () => {
+ create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
+ // Anyone can create an experiment
+ requireNothing(ctx);
+
const maxSortIndex =
(
await prisma.experiment.aggregate({
@@ -62,6 +93,7 @@ export const experimentsRouter = createTRPCRouter({
data: {
sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`,
+ organizationId: (await userOrg(ctx.session.user.id)).id,
},
});
@@ -117,9 +149,10 @@ export const experimentsRouter = createTRPCRouter({
return exp;
}),
- update: publicProcedure
+ update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ await requireCanModifyExperiment(input.id, ctx);
return await prisma.experiment.update({
where: {
id: input.id,
@@ -130,11 +163,15 @@ export const experimentsRouter = createTRPCRouter({
});
}),
- delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
- await prisma.experiment.delete({
- where: {
- id: input.id,
- },
- });
- }),
+ delete: protectedProcedure
+ .input(z.object({ id: z.string() }))
+ .mutation(async ({ input, ctx }) => {
+ await requireCanModifyExperiment(input.id, ctx);
+
+ await prisma.experiment.delete({
+ where: {
+ id: input.id,
+ },
+ });
+ }),
});
diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts
index ac9ead2..d08243a 100644
--- a/src/server/api/routers/promptVariants.router.ts
+++ b/src/server/api/routers/promptVariants.router.ts
@@ -1,6 +1,6 @@
import { isObject } from "lodash-es";
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
@@ -11,129 +11,140 @@ import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const promptVariantsRouter = createTRPCRouter({
- list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
- return await prisma.promptVariant.findMany({
- where: {
- experimentId: input.experimentId,
- visible: true,
- },
- orderBy: { sortIndex: "asc" },
- });
- }),
+ list: publicProcedure
+ .input(z.object({ experimentId: z.string() }))
+ .query(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.experimentId, ctx);
- stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
- const variant = await prisma.promptVariant.findUnique({
- where: {
- id: input.variantId,
- },
- });
+ return await prisma.promptVariant.findMany({
+ where: {
+ experimentId: input.experimentId,
+ visible: true,
+ },
+ orderBy: { sortIndex: "asc" },
+ });
+ }),
- if (!variant) {
- throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
- }
+ stats: publicProcedure
+ .input(z.object({ variantId: z.string() }))
+ .query(async ({ input, ctx }) => {
+ const variant = await prisma.promptVariant.findUnique({
+ where: {
+ id: input.variantId,
+ },
+ });
- const outputEvals = await prisma.outputEvaluation.groupBy({
- by: ["evaluationId"],
- _sum: {
- result: true,
- },
- _count: {
- id: true,
- },
- where: {
- modelOutput: {
- scenarioVariantCell: {
- promptVariant: {
- id: input.variantId,
- visible: true,
+ if (!variant) {
+ throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
+ }
+
+ await requireCanViewExperiment(variant.experimentId, ctx);
+
+ const outputEvals = await prisma.outputEvaluation.groupBy({
+ by: ["evaluationId"],
+ _sum: {
+ result: true,
+ },
+ _count: {
+ id: true,
+ },
+ where: {
+ modelOutput: {
+ scenarioVariantCell: {
+ promptVariant: {
+ id: input.variantId,
+ visible: true,
+ },
+ testScenario: {
+ visible: true,
+ },
},
+ },
+ },
+ });
+
+ const evals = await prisma.evaluation.findMany({
+ where: {
+ experimentId: variant.experimentId,
+ },
+ });
+
+ const evalResults = evals.map((evalItem) => {
+ const evalResult = outputEvals.find(
+ (outputEval) => outputEval.evaluationId === evalItem.id,
+ );
+ return {
+ id: evalItem.id,
+ label: evalItem.label,
+ passCount: evalResult?._sum?.result ?? 0,
+ totalCount: evalResult?._count?.id ?? 1,
+ };
+ });
+
+ const scenarioCount = await prisma.testScenario.count({
+ where: {
+ experimentId: variant.experimentId,
+ visible: true,
+ },
+ });
+ const outputCount = await prisma.scenarioVariantCell.count({
+ where: {
+ promptVariantId: input.variantId,
+ testScenario: { visible: true },
+ modelOutput: {
+ is: {},
+ },
+ },
+ });
+
+ const overallTokens = await prisma.modelOutput.aggregate({
+ where: {
+ scenarioVariantCell: {
+ promptVariantId: input.variantId,
testScenario: {
visible: true,
},
},
},
- },
- });
-
- const evals = await prisma.evaluation.findMany({
- where: {
- experimentId: variant.experimentId,
- },
- });
-
- const evalResults = evals.map((evalItem) => {
- const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
- return {
- id: evalItem.id,
- label: evalItem.label,
- passCount: evalResult?._sum?.result ?? 0,
- totalCount: evalResult?._count?.id ?? 1,
- };
- });
-
- const scenarioCount = await prisma.testScenario.count({
- where: {
- experimentId: variant.experimentId,
- visible: true,
- },
- });
- const outputCount = await prisma.scenarioVariantCell.count({
- where: {
- promptVariantId: input.variantId,
- testScenario: { visible: true },
- modelOutput: {
- is: {},
+ _sum: {
+ promptTokens: true,
+ completionTokens: true,
},
- },
- });
+ });
- const overallTokens = await prisma.modelOutput.aggregate({
- where: {
- scenarioVariantCell: {
+ const promptTokens = overallTokens._sum?.promptTokens ?? 0;
+ const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
+ const completionTokens = overallTokens._sum?.completionTokens ?? 0;
+ const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
+
+ const overallCost = overallPromptCost + overallCompletionCost;
+
+ const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
+ where: {
promptVariantId: input.variantId,
- testScenario: {
- visible: true,
+ testScenario: { visible: true },
+ // Check if is PENDING or IN_PROGRESS
+ retrievalStatus: {
+ in: ["PENDING", "IN_PROGRESS"],
},
},
- },
- _sum: {
- promptTokens: true,
- completionTokens: true,
- },
- });
+ }));
- const promptTokens = overallTokens._sum?.promptTokens ?? 0;
- const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
- const completionTokens = overallTokens._sum?.completionTokens ?? 0;
- const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
+ return {
+ evalResults,
+ promptTokens,
+ completionTokens,
+ overallCost,
+ scenarioCount,
+ outputCount,
+ awaitingRetrievals,
+ };
+ }),
- const overallCost = overallPromptCost + overallCompletionCost;
-
- const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
- where: {
- promptVariantId: input.variantId,
- testScenario: { visible: true },
- // Check if is PENDING or IN_PROGRESS
- retrievalStatus: {
- in: ["PENDING", "IN_PROGRESS"],
- },
- },
- }));
-
- return {
- evalResults,
- promptTokens,
- completionTokens,
- overallCost,
- scenarioCount,
- outputCount,
- awaitingRetrievals,
- };
- }),
-
- create: publicProcedure
+ create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
@@ -141,7 +152,9 @@ export const promptVariantsRouter = createTRPCRouter({
newModel: z.string().optional(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.experimentId, ctx);
+
let originalVariant: PromptVariant | null = null;
if (input.variantId) {
originalVariant = await prisma.promptVariant.findUnique({
@@ -217,7 +230,7 @@ export const promptVariantsRouter = createTRPCRouter({
return newVariant;
}),
- update: publicProcedure
+ update: protectedProcedure
.input(
z.object({
id: z.string(),
@@ -226,7 +239,7 @@ export const promptVariantsRouter = createTRPCRouter({
}),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUnique({
where: {
id: input.id,
@@ -237,6 +250,8 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
+ await requireCanModifyExperiment(existing.experimentId, ctx);
+
const updatePromptVariantAction = prisma.promptVariant.update({
where: {
id: input.id,
@@ -252,13 +267,18 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
- hide: publicProcedure
+ hide: protectedProcedure
.input(
z.object({
id: z.string(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+ where: { id: input.id },
+ });
+ await requireCanModifyExperiment(experimentId, ctx);
+
const updatedPromptVariant = await prisma.promptVariant.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -267,19 +287,20 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
- replaceVariant: publicProcedure
+ replaceVariant: protectedProcedure
.input(
z.object({
id: z.string(),
constructFn: z.string(),
}),
)
- .mutation(async ({ input }) => {
- const existing = await prisma.promptVariant.findUnique({
+ .mutation(async ({ input, ctx }) => {
+ const existing = await prisma.promptVariant.findUniqueOrThrow({
where: {
id: input.id,
},
});
+ await requireCanModifyExperiment(existing.experimentId, ctx);
if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
@@ -347,14 +368,19 @@ export const promptVariantsRouter = createTRPCRouter({
return { status: "ok" } as const;
}),
- reorder: publicProcedure
+ reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
+ where: { id: input.draggedId },
+ });
+ await requireCanModifyExperiment(experimentId, ctx);
+
await reorderPromptVariants(input.draggedId, input.droppedId);
}),
});
diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts
index b07657e..fee3d7a 100644
--- a/src/server/api/routers/scenarioVariantCells.router.ts
+++ b/src/server/api/routers/scenarioVariantCells.router.ts
@@ -1,8 +1,9 @@
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVariantCellsRouter = createTRPCRouter({
get: publicProcedure
@@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
variantId: z.string(),
}),
)
- .query(async ({ input }) => {
+ .query(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
+ where: { id: input.scenarioId },
+ });
+ await requireCanViewExperiment(experimentId, ctx);
+
return await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
@@ -35,14 +41,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
},
});
}),
- forceRefetch: publicProcedure
+ forceRefetch: protectedProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
+ where: { id: input.scenarioId },
+ });
+
+ await requireCanModifyExperiment(experimentId, ctx);
+
const cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts
index 0ddfb0b..91f1852 100644
--- a/src/server/api/routers/scenarios.router.ts
+++ b/src/server/api/routers/scenarios.router.ts
@@ -1,32 +1,39 @@
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenariosRouter = createTRPCRouter({
- list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
- return await prisma.testScenario.findMany({
- where: {
- experimentId: input.experimentId,
- visible: true,
- },
- orderBy: {
- sortIndex: "asc",
- },
- });
- }),
+ list: publicProcedure
+ .input(z.object({ experimentId: z.string() }))
+ .query(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.experimentId, ctx);
- create: publicProcedure
+ return await prisma.testScenario.findMany({
+ where: {
+ experimentId: input.experimentId,
+ visible: true,
+ },
+ orderBy: {
+ sortIndex: "asc",
+ },
+ });
+ }),
+
+ create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
autogenerate: z.boolean().optional(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ await requireCanModifyExperiment(input.experimentId, ctx);
+
const maxSortIndex =
(
await prisma.testScenario.aggregate({
@@ -66,7 +73,14 @@ export const scenariosRouter = createTRPCRouter({
}
}),
- hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
+ hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
+ const experimentId = (
+ await prisma.testScenario.findUniqueOrThrow({
+ where: { id: input.id },
+ })
+ ).experimentId;
+
+ await requireCanModifyExperiment(experimentId, ctx);
const hiddenScenario = await prisma.testScenario.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -78,14 +92,14 @@ export const scenariosRouter = createTRPCRouter({
return hiddenScenario;
}),
- reorder: publicProcedure
+ reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
const dragged = await prisma.testScenario.findUnique({
where: {
id: input.draggedId,
@@ -104,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
);
}
+ await requireCanModifyExperiment(dragged.experimentId, ctx);
+
const visibleItems = await prisma.testScenario.findMany({
where: {
experimentId: dragged.experimentId,
@@ -147,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
);
}),
- replaceWithValues: publicProcedure
+ replaceWithValues: protectedProcedure
.input(
z.object({
id: z.string(),
values: z.record(z.string()),
}),
)
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
const existing = await prisma.testScenario.findUnique({
where: {
id: input.id,
@@ -165,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
throw new Error(`Scenario with id ${input.id} does not exist`);
}
+ await requireCanModifyExperiment(existing.experimentId, ctx);
+
const newScenario = await prisma.testScenario.create({
data: {
experimentId: existing.experimentId,
diff --git a/src/server/api/routers/templateVariables.router.ts b/src/server/api/routers/templateVariables.router.ts
index 6762112..d62fec4 100644
--- a/src/server/api/routers/templateVariables.router.ts
+++ b/src/server/api/routers/templateVariables.router.ts
@@ -1,11 +1,14 @@
import { z } from "zod";
-import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
+import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
+import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const templateVarsRouter = createTRPCRouter({
- create: publicProcedure
+ create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() }))
- .mutation(async ({ input }) => {
+ .mutation(async ({ input, ctx }) => {
+ await requireCanModifyExperiment(input.experimentId, ctx);
+
await prisma.templateVariable.create({
data: {
experimentId: input.experimentId,
@@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({
});
}),
- delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
- await prisma.templateVariable.delete({ where: { id: input.id } });
- }),
+ delete: protectedProcedure
+ .input(z.object({ id: z.string() }))
+ .mutation(async ({ input, ctx }) => {
+ const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
+ where: { id: input.id },
+ });
- list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
- return await prisma.templateVariable.findMany({
- where: {
- experimentId: input.experimentId,
- },
- orderBy: {
- createdAt: "asc",
- },
- select: {
- id: true,
- label: true,
- },
- });
- }),
+ await requireCanModifyExperiment(experimentId, ctx);
+
+ await prisma.templateVariable.delete({ where: { id: input.id } });
+ }),
+
+ list: publicProcedure
+ .input(z.object({ experimentId: z.string() }))
+ .query(async ({ input, ctx }) => {
+ await requireCanViewExperiment(input.experimentId, ctx);
+ return await prisma.templateVariable.findMany({
+ where: {
+ experimentId: input.experimentId,
+ },
+ orderBy: {
+ createdAt: "asc",
+ },
+ select: {
+ id: true,
+ label: true,
+ },
+ });
+ }),
});
diff --git a/src/server/api/trpc.ts b/src/server/api/trpc.ts
index 244cabd..e795857 100644
--- a/src/server/api/trpc.ts
+++ b/src/server/api/trpc.ts
@@ -27,6 +27,9 @@ type CreateContextOptions = {
session: Session | null;
};
+// eslint-disable-next-line @typescript-eslint/no-empty-function
+const noOp = () => {};
+
/**
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
* it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
session: opts.session,
prisma,
+ markAccessControlRun: noOp,
};
};
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
* errors on the backend.
*/
+export type TRPCContext = Awaited>;
+
const t = initTRPC.context().create({
transformer: superjson,
errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
export const publicProcedure = t.procedure;
/** Reusable middleware that enforces users are logged in before running the procedure. */
-const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
+const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
if (!ctx.session || !ctx.session.user) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
- return next({
+
+ let accessControlRun = false;
+ const resp = await next({
ctx: {
// infers the `session` as non-nullable
session: { ...ctx.session, user: ctx.session.user },
+ markAccessControlRun: () => {
+ accessControlRun = true;
+ },
},
});
+ if (!accessControlRun)
+ throw new TRPCError({
+ code: "INTERNAL_SERVER_ERROR",
+ message:
+ "Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
+ });
+
+ return resp;
});
/**
diff --git a/src/server/auth.ts b/src/server/auth.ts
index 4531237..f2b779c 100644
--- a/src/server/auth.ts
+++ b/src/server/auth.ts
@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
import { type GetServerSidePropsContext } from "next";
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
import { prisma } from "~/server/db";
+import GitHubProvider from "next-auth/providers/github";
+import { env } from "~/env.mjs";
/**
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
},
adapter: PrismaAdapter(prisma),
providers: [
- // DiscordProvider({
- // clientId: env.DISCORD_CLIENT_ID,
- // clientSecret: env.DISCORD_CLIENT_SECRET,
- // }),
- /**
- * ...add more providers here.
- *
- * Most other providers require a bit more work than the Discord provider. For example, the
- * GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
- * model. Refer to the NextAuth.js docs for the provider you want to use. Example:
- *
- * @see https://next-auth.js.org/providers/github
- */
+ GitHubProvider({
+ clientId: env.GITHUB_CLIENT_ID,
+ clientSecret: env.GITHUB_CLIENT_SECRET,
+ }),
],
+ theme: {
+ logo: "/logo.svg",
+ brandColor: "#ff5733",
+ },
};
/**
diff --git a/src/server/utils/userOrg.ts b/src/server/utils/userOrg.ts
new file mode 100644
index 0000000..4158f87
--- /dev/null
+++ b/src/server/utils/userOrg.ts
@@ -0,0 +1,19 @@
+import { prisma } from "~/server/db";
+
+export default async function userOrg(userId: string) {
+ return await prisma.organization.upsert({
+ where: {
+ personalOrgUserId: userId,
+ },
+ update: {},
+ create: {
+ personalOrgUserId: userId,
+ OrganizationUser: {
+ create: {
+ userId: userId,
+ role: "ADMIN",
+ },
+ },
+ },
+ });
+}
diff --git a/src/utils/accessControl.ts b/src/utils/accessControl.ts
new file mode 100644
index 0000000..e43d57d
--- /dev/null
+++ b/src/utils/accessControl.ts
@@ -0,0 +1,49 @@
+import { OrganizationUserRole } from "@prisma/client";
+import { TRPCError } from "@trpc/server";
+import { type TRPCContext } from "~/server/api/trpc";
+import { prisma } from "~/server/db";
+
+// No-op method for protected routes that really should be accessible to anyone.
+export const requireNothing = (ctx: TRPCContext) => {
+ ctx.markAccessControlRun();
+};
+
+export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
+ await prisma.experiment.findFirst({
+ where: { id: experimentId },
+ });
+
+ // Right now all experiments are publicly viewable, so this is a no-op.
+ ctx.markAccessControlRun();
+};
+
+export const canModifyExperiment = async (experimentId: string, userId: string) => {
+ const experiment = await prisma.experiment.findFirst({
+ where: {
+ id: experimentId,
+ organization: {
+ OrganizationUser: {
+ some: {
+ role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
+ userId,
+ },
+ },
+ },
+ },
+ });
+
+ return !!experiment;
+};
+
+export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
+ const userId = ctx.session?.user.id;
+ if (!userId) {
+ throw new TRPCError({ code: "UNAUTHORIZED" });
+ }
+
+ if (!(await canModifyExperiment(experimentId, userId))) {
+ throw new TRPCError({ code: "UNAUTHORIZED" });
+ }
+
+ ctx.markAccessControlRun();
+};
diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts
index 7bb1c4b..fed6637 100644
--- a/src/utils/hooks.ts
+++ b/src/utils/hooks.ts
@@ -12,6 +12,10 @@ export const useExperiment = () => {
return experiment;
};
+export const useExperimentAccess = () => {
+ return useExperiment().data?.access ?? { canView: false, canModify: false };
+};
+
type AsyncFunction = (...args: T) => Promise;
export function useHandledAsyncCallback(