User accounts

Allows for the creation of user accounts. A few notes on the specifics:

 - Experiments are the main access control objects. If you can view an experiment, you can view all its prompts/scenarios/evals. If you can edit it, you can edit or delete all of those as well.
 - Experiments are owned by Organizations in the database. Organizations can have multiple members and members can have roles of ADMIN, MEMBER or VIEWER.
 - Organizations can either be "personal" or general. Each user has a "personal" organization created as soon as they try to create an experiment. There's currently no UI support for creating general orgs or adding users to them; they're just in the database to future-proof all the ACL logic.
 - You can require that a user is signed-in to see a route using the `protectedProcedure` helper. When you use `protectedProcedure`, you also have to call `ctx.markAccessControlRun()` (or delegate to a function that does it for you; see accessControl.ts). This is to remind us to actually check for access control when we define a new endpoint.
This commit is contained in:
Kyle Corbitt
2023-07-18 17:39:14 -07:00
parent e0e64c4207
commit 1dcdba04a6
34 changed files with 963 additions and 416 deletions

View File

@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
OPENAI_API_KEY=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
# Next Auth
NEXTAUTH_SECRET="your_secret"
NEXTAUTH_URL="http://localhost:3000"
# Next Auth Github Provider
GITHUB_CLIENT_ID="your_client_id"
GITHUB_CLIENT_SECRET="your_secret"

View File

@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
} from "next";
export type Route =
| StaticRoute<"/account/signin">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
| DynamicRoute<"/experiments/[id]", { "id": string }>

View File

@@ -4,11 +4,18 @@
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
**Live Demo:** https://openpipe.ai
## Sample Experiments
These are simple experiments users have created that show how OpenPipe works.
- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally).
You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
## High-Level Features
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
5. Install the dependencies: `cd openpipe && pnpm install`
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
8. Start the app: `pnpm dev`.
9. Navigate to [http://localhost:3000](http://localhost:3000)
8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
9. Start the app: `pnpm dev`.
10. Navigate to [http://localhost:3000](http://localhost:3000)

2
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.0'
lockfileVersion: '6.1'
settings:
autoInstallPeers: true

View File

@@ -0,0 +1,124 @@
DROP TABLE "Account";
DROP TABLE "Session";
DROP TABLE "User";
DROP TABLE "VerificationToken";
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
-- CreateTable
CREATE TABLE "Organization" (
"id" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"personalOrgUserId" UUID,
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OrganizationUser" (
"id" UUID NOT NULL,
"role" "OrganizationUserRole" NOT NULL,
"organizationId" UUID NOT NULL,
"userId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Account" (
"id" UUID NOT NULL,
"userId" UUID NOT NULL,
"type" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"providerAccountId" TEXT NOT NULL,
"refresh_token" TEXT,
"refresh_token_expires_in" INTEGER,
"access_token" TEXT,
"expires_at" INTEGER,
"token_type" TEXT,
"scope" TEXT,
"id_token" TEXT,
"session_state" TEXT,
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Session" (
"id" UUID NOT NULL,
"sessionToken" TEXT NOT NULL,
"userId" UUID NOT NULL,
"expires" TIMESTAMP(3) NOT NULL,
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "User" (
"id" UUID NOT NULL,
"name" TEXT,
"email" TEXT,
"emailVerified" TIMESTAMP(3),
"image" TEXT,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "VerificationToken" (
"identifier" TEXT NOT NULL,
"token" TEXT NOT NULL,
"expires" TIMESTAMP(3) NOT NULL
);
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
-- AlterTable add organizationId as NULLABLE
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
-- Set default organization for existing experiments
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
-- AlterTable set organizationId as NOT NULL
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
-- CreateIndex
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
-- CreateIndex
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
-- CreateIndex
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
-- AddForeignKey
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -16,8 +16,12 @@ model Experiment {
sortIndex Int @default(0)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
organizationId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[]
TestScenario TestScenario[]
@@ -169,41 +173,77 @@ model OutputEvaluation {
@@unique([modelOutputId, evaluationId])
}
// Necessary for Next auth
model Organization {
id String @id @default(uuid()) @db.Uuid
personalOrgUserId String? @unique @db.Uuid
PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
OrganizationUser OrganizationUser[]
Experiment Experiment[]
}
enum OrganizationUserRole {
ADMIN
MEMBER
VIEWER
}
model OrganizationUser {
id String @id @default(uuid()) @db.Uuid
role OrganizationUserRole
organizationId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
userId String @db.Uuid
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([organizationId, userId])
}
model Account {
id String @id @default(cuid())
userId String
type String
provider String
providerAccountId String
refresh_token String? // @db.Text
access_token String? // @db.Text
expires_at Int?
token_type String?
scope String?
id_token String? // @db.Text
session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
id String @id @default(uuid()) @db.Uuid
userId String @db.Uuid
type String
provider String
providerAccountId String
refresh_token String? @db.Text
refresh_token_expires_in Int?
access_token String? @db.Text
expires_at Int?
token_type String?
scope String?
id_token String? @db.Text
session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([provider, providerAccountId])
}
model Session {
id String @id @default(cuid())
id String @id @default(uuid()) @db.Uuid
sessionToken String @unique
userId String
userId String @db.Uuid
expires DateTime
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
model User {
id String @id @default(cuid())
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
id String @id @default(uuid()) @db.Uuid
name String?
email String? @unique
emailVerified DateTime?
image String?
accounts Account[]
sessions Session[]
OrganizationUser OrganizationUser[]
Organization Organization[]
}
model VerificationToken {

View File

@@ -2,40 +2,47 @@ import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111";
const defaultId = "11111111-1111-1111-1111-111111111111";
await prisma.organization.deleteMany({
where: { id: defaultId },
});
await prisma.organization.create({
data: { id: defaultId },
});
// Delete the existing experiment
await prisma.experiment.deleteMany({
where: {
id: experimentId,
id: defaultId,
},
});
await prisma.experiment.create({
data: {
id: experimentId,
id: defaultId,
label: "Country Capitals Example",
organizationId: defaultId,
},
});
await prisma.scenarioVariantCell.deleteMany({
where: {
promptVariant: {
experimentId,
experimentId: defaultId,
},
},
});
await prisma.promptVariant.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.promptVariant.createMany({
data: [
{
experimentId,
experimentId: defaultId,
label: "Prompt Variant 1",
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
@@ -52,7 +59,7 @@ await prisma.promptVariant.createMany({
}`,
},
{
experimentId,
experimentId: defaultId,
label: "Prompt Variant 2",
sortIndex: 1,
model: "gpt-3.5-turbo-0613",
@@ -73,14 +80,14 @@ await prisma.promptVariant.createMany({
await prisma.templateVariable.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.templateVariable.createMany({
data: [
{
experimentId,
experimentId: defaultId,
label: "country",
},
],
@@ -88,28 +95,28 @@ await prisma.templateVariable.createMany({
await prisma.testScenario.deleteMany({
where: {
experimentId,
experimentId: defaultId,
},
});
await prisma.testScenario.createMany({
data: [
{
experimentId,
experimentId: defaultId,
sortIndex: 0,
variableValues: {
country: "Spain",
},
},
{
experimentId,
experimentId: defaultId,
sortIndex: 1,
variableValues: {
country: "USA",
},
},
{
experimentId,
experimentId: defaultId,
sortIndex: 2,
variableValues: {
country: "Chile",
@@ -120,13 +127,13 @@ await prisma.testScenario.createMany({
const variants = await prisma.promptVariant.findMany({
where: {
experimentId,
experimentId: defaultId,
},
});
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId,
experimentId: defaultId,
},
});

View File

@@ -12,7 +12,7 @@ services:
dockerContext: .
plan: standard
domains:
- openpipe.ai
- app.openpipe.ai
envVars:
- key: NODE_ENV
value: production

View File

@@ -1,7 +1,7 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => (
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
);
export default function NewScenarioButton() {
const { canModify } = useExperimentAccess();
const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation();
const utils = api.useContext();
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
await utils.scenarios.list.invalidate();
}, [mutation]);
if (!canModify) return null;
return (
<HStack spacing={2}>
<StyledButton onClick={onClick}>

View File

@@ -1,7 +1,7 @@
import { Button, Icon, Spinner } from "@chakra-ui/react";
import { Box, Button, Icon, Spinner } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() {
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
await utils.promptVariants.list.invalidate();
}, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return (
<Button
w="100%"

View File

@@ -1,5 +1,6 @@
import { Button, HStack, Icon } from "@chakra-ui/react";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";
export const CellOptions = ({
refetchingOutput,
@@ -8,25 +9,28 @@ export const CellOptions = ({
refetchingOutput: boolean;
refetchOutput: () => void;
}) => {
const { canModify } = useExperimentAccess();
return (
<HStack justifyContent="flex-end" w="full">
{!refetchingOutput && (
<Button
size="xs"
w={4}
h={4}
py={4}
px={4}
minW={0}
borderRadius={8}
color="gray.500"
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={4} />
</Button>
{!refetchingOutput && canModify && (
<Tooltip label="Refetch output" aria-label="refetch output">
<Button
size="xs"
w={4}
h={4}
py={4}
px={4}
minW={0}
borderRadius={8}
color="gray.500"
variant="ghost"
cursor="pointer"
onClick={refetchOutput}
aria-label="refetch output"
>
<Icon as={BsArrowClockwise} boxSize={4} />
</Button>
</Tooltip>
)}
</HStack>
);

View File

@@ -2,7 +2,7 @@ import { type DragEvent } from "react";
import { api } from "~/utils/api";
import { isEqual } from "lodash-es";
import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react";
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
@@ -19,6 +19,8 @@ export default function ScenarioEditor({
hovered: boolean;
canHide: boolean;
}) {
const { canModify } = useExperimentAccess();
const savedValues = scenario.variableValues as Record<string, string>;
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
@@ -74,6 +76,7 @@ export default function ScenarioEditor({
alignItems="flex-start"
pr={cellPadding.x}
py={cellPadding.y}
pl={canModify ? 0 : cellPadding.x}
height="100%"
draggable={!variableInputHovered}
onDragStart={(e) => {
@@ -93,35 +96,38 @@ export default function ScenarioEditor({
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
{props.canHide && (
<>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
{canModify && (
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
{props.canHide && (
<>
<Tooltip label="Hide scenario" hasArrow>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
<Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</>
)}
</Stack>
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</>
)}
</Stack>
)}
{variableLabels.length === 0 ? (
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
) : (
@@ -155,6 +161,8 @@ export default function ScenarioEditor({
fontSize="sm"
lineHeight={1.2}
value={value}
isDisabled={!canModify}
_disabled={{ opacity: 1, cursor: "default" }}
onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value }));
}}

View File

@@ -1,6 +1,6 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { useElementDimensions } from "~/utils/hooks";
import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs";
import { useAppStore } from "~/state/store";
@@ -13,6 +13,7 @@ export const ScenariosHeader = ({
numScenarios: number;
}) => {
const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const [ref, dimensions] = useElementDimensions();
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
@@ -33,16 +34,18 @@ export const ScenariosHeader = ({
<Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({numScenarios})
</Heading>
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
{canModify && (
<Button
size="xs"
variant="ghost"
color="gray.500"
aria-label="Edit"
leftIcon={<BsPencil />}
onClick={openDrawer}
>
Edit Vars
</Button>
)}
</HStack>
</GridItem>
);

View File

@@ -1,11 +1,12 @@
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -40,18 +41,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const model = editorRef.current.getModel();
if (!model) return;
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
if (hasErrors) {
toast({
title: "Invalid TypeScript",
description: "Please fix the TypeScript errors before saving.",
status: "error",
});
return;
}
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere
const promptRegex = /prompt\s*=/;
if (!promptRegex.test(currentFn)) {
@@ -103,6 +92,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "",
quickSuggestions: true,
readOnly: !canModify,
});
editorRef.current.onDidFocusEditorText(() => {
@@ -130,6 +120,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
/* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]);
useEffect(() => {
if (!editorRef.current) return;
editorRef.current.updateOptions({
readOnly: !canModify,
});
}, [canModify]);
return (
<Box w="100%" pos="relative">
<div id={editorId} style={{ height: "400px", width: "100%" }}></div>

View File

@@ -2,7 +2,7 @@ import { type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky",
top: "-1px",
top: "0",
backgroundColor: "#fff",
zIndex: 1,
};

View File

@@ -1,21 +0,0 @@
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
import { BsExclamationTriangleFill } from "react-icons/bs";
import { env } from "~/env.mjs";
export default function PublicPlaygroundWarning() {
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
return (
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
<Text>
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
private use,{" "}
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
run a local copy
</Link>
.
</Text>
</Flex>
);
}

View File

@@ -1,15 +1,16 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, GridItem } from "@chakra-ui/react"; // Changed here
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
@@ -44,6 +45,16 @@ export default function VariantHeader(props: { variant: PromptVariant; canHide:
const [menuOpen, setMenuOpen] = useState(false);
if (!canModify) {
return (
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
{props.variant.label}
</Text>
</GridItem>
);
}
return (
<GridItem
padding={0}

View File

@@ -14,17 +14,19 @@ import {
Link,
} from "@chakra-ui/react";
import Head from "next/head";
import { BsGithub, BsTwitter } from "react-icons/bs";
import { BsGithub, BsPersonCircle } from "react-icons/bs";
import { useRouter } from "next/router";
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
import { type IconType } from "react-icons";
import { RiFlaskLine } from "react-icons/ri";
import { useState, useEffect } from "react";
import { signIn, useSession } from "next-auth/react";
import UserMenu from "./UserMenu";
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string };
type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
const isActive = useRouter().pathname.startsWith(href);
const router = useRouter();
const isActive = href && router.pathname.startsWith(href);
return (
<Box
as={Link}
@@ -32,7 +34,7 @@ const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps)
target={target}
w="full"
bgColor={isActive ? "gray.300" : "transparent"}
_hover={{ bgColor: "gray.300" }}
_hover={{ bgColor: "gray.200", textDecoration: "none" }}
py={4}
justifyContent="start"
cursor="pointer"
@@ -47,6 +49,8 @@ const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps)
};
const NavSidebar = () => {
const user = useSession().data;
return (
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%">
<Link href="/" w="full" _hover={{ textDecoration: "none" }}>
@@ -59,26 +63,32 @@ const NavSidebar = () => {
</Link>
<Divider />
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
{user != null && (
<>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
</>
)}
{user === null && (
<IconLink
icon={BsPersonCircle}
label="Sign In"
onClick={() => {
signIn("github").catch(console.error);
}}
/>
)}
</VStack>
<Divider />
<VStack w="full" spacing={0} pb={2}>
<IconLink
icon={BsGithub}
label="GitHub"
{user ? <UserMenu user={user} /> : <Divider />}
<VStack spacing={0} align="center">
<Link
href="https://github.com/openpipe/openpipe"
target="_blank"
color="gray.500"
_hover={{ color: "gray.800" }}
/>
<IconLink
icon={BsTwitter}
label="Twitter"
href="https://twitter.com/corbtt"
target="_blank"
color="gray.500"
_hover={{ color: "gray.800" }}
/>
p={2}
>
<Icon as={BsGithub} boxSize={6} />
</Link>
</VStack>
</VStack>
);
@@ -108,16 +118,13 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
<Grid
h={vh}
w="100vw"
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
templateRows="max-content 1fr"
templateAreas={'"warning warning"\n"sidebar main"'}
templateColumns={{ base: "56px minmax(0, 1fr)", md: "220px minmax(0, 1fr)" }}
templateRows="1fr"
templateAreas={'"sidebar main"'}
>
<Head>
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
</Head>
<GridItem area="warning">
<PublicPlaygroundWarning />
</GridItem>
<GridItem area="sidebar" overflow="hidden">
<NavSidebar />
</GridItem>

View File

@@ -0,0 +1,72 @@
import {
HStack,
Icon,
Image,
VStack,
Text,
Popover,
PopoverTrigger,
PopoverContent,
Link,
} from "@chakra-ui/react";
import { type Session } from "next-auth";
import { signOut } from "next-auth/react";
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
export default function UserMenu({ user }: { user: Session }) {
const profileImage = user.user.image ? (
<Image src={user.user.image} alt="profile picture" w={8} h={8} borderRadius="50%" />
) : (
<Icon as={BsPersonCircle} boxSize="md" />
);
return (
<>
<Popover placement="right">
<PopoverTrigger>
<HStack
px={2}
py={2}
borderColor={"gray.200"}
borderTopWidth={1}
borderBottomWidth={1}
cursor="pointer"
_hover={{
bgColor: "gray.200",
}}
>
{profileImage}
<VStack spacing={0} align="start" flex={1}>
<Text fontWeight="bold" fontSize="sm">
{user.user.name}
</Text>
<Text color="gray.500" fontSize="xs">
{user.user.email}
</Text>
</VStack>
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
</HStack>
</PopoverTrigger>
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
<VStack align="stretch" spacing={0}>
{/* sign out */}
<HStack
as={Link}
onClick={() => {
signOut().catch(console.error);
}}
px={4}
py={2}
spacing={4}
color="gray.500"
fontSize="sm"
>
<Icon as={BsBoxArrowRight} boxSize={6} />
<Text>Sign out</Text>
</HStack>
</VStack>
</PopoverContent>
</Popover>
</>
);
}

View File

@@ -15,6 +15,8 @@ export const env = createEnv({
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
},
/**
@@ -24,11 +26,6 @@ export const env = createEnv({
*/
client: {
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
},
@@ -42,8 +39,9 @@ export const env = createEnv({
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
},
/**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.

View File

@@ -0,0 +1,23 @@
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useEffect } from "react";
import AppShell from "~/components/nav/AppShell";
export default function SignIn() {
const session = useSession().data;
const router = useRouter();
useEffect(() => {
if (session) {
router.push("/experiments").catch(console.error);
} else if (session === null) {
signIn("github").catch(console.error);
}
}, [session, router]);
return (
<AppShell>
<div />
</AppShell>
);
}

View File

@@ -124,6 +124,8 @@ export default function Experiment() {
);
}
const canModify = experiment.data?.access.canModify ?? false;
return (
<AppShell title={experiment.data?.label}>
<VStack h="full">
@@ -143,37 +145,45 @@ export default function Experiment() {
</Link>
</BreadcrumbItem>
<BreadcrumbItem isCurrentPage>
<Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
borderWidth={1}
borderColor="transparent"
fontSize={16}
px={0}
minW={{ base: 100, lg: 300 }}
flex={1}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
/>
{canModify ? (
<Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
borderWidth={1}
borderColor="transparent"
fontSize={16}
px={0}
minW={{ base: 100, lg: 300 }}
flex={1}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
/>
) : (
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
{experiment.data?.label}
</Text>
)}
</BreadcrumbItem>
</Breadcrumb>
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
{canModify && (
<HStack>
<Button
size="sm"
variant={{ base: "outline", lg: "ghost" }}
colorScheme="gray"
fontWeight="normal"
onClick={openDrawer}
>
<Icon as={BsGearFill} boxSize={4} color="gray.600" />
<Text display={{ base: "none", lg: "block" }} ml={2}>
Edit Vars & Evals
</Text>
</Button>
<DeleteButton />
</HStack>
)}
</Flex>
<SettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}>

View File

@@ -6,18 +6,44 @@ import {
Breadcrumb,
BreadcrumbItem,
Flex,
Center,
Text,
Link,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api";
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton";
import { ExperimentCard } from "~/components/experiments/ExperimentCard";
import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery();
const user = useSession().data;
if (user === null) {
return (
<AppShell title="Experiments">
<Center h="100%">
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
</Center>
</AppShell>
);
}
return (
<AppShell>
<AppShell title="Experiments">
<VStack alignItems={"flex-start"} m={4} mt={1}>
<HStack w="full" justifyContent="space-between" mb={4}>
<Breadcrumb flex={1}>

View File

@@ -1,20 +1,25 @@
import { EvalType } from "@prisma/client";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { runAllEvals } from "~/server/utils/evaluations";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
create: publicProcedure
return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
@@ -23,7 +28,9 @@ export const evaluationsRouter = createTRPCRouter({
evalType: z.nativeEnum(EvalType),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.evaluation.create({
data: {
experimentId: input.experimentId,
@@ -38,7 +45,7 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(input.experimentId);
}),
update: publicProcedure
update: protectedProcedure
.input(
z.object({
id: z.string(),
@@ -49,7 +56,12 @@ export const evaluationsRouter = createTRPCRouter({
}),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const evaluation = await prisma.evaluation.update({
where: { id: input.id },
data: {
@@ -69,9 +81,16 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(evaluation.experimentId);
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
});

View File

@@ -1,12 +1,29 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
import {
canModifyExperiment,
requireCanModifyExperiment,
requireCanViewExperiment,
requireNothing,
} from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
export const experimentsRouter = createTRPCRouter({
list: publicProcedure.query(async () => {
list: protectedProcedure.query(async ({ ctx }) => {
// Anyone can list experiments
requireNothing(ctx);
const experiments = await prisma.experiment.findMany({
where: {
organization: {
OrganizationUser: {
some: { userId: ctx.session.user.id },
},
},
},
orderBy: {
sortIndex: "asc",
},
@@ -40,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({
return experimentsWithCounts;
}),
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => {
return await prisma.experiment.findFirst({
where: {
id: input.id,
},
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id },
});
const canModify = ctx.session?.user.id
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
: false;
return {
...experiment,
access: {
canView: true,
canModify,
},
};
}),
create: publicProcedure.input(z.object({})).mutation(async () => {
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
// Anyone can create an experiment
requireNothing(ctx);
const maxSortIndex =
(
await prisma.experiment.aggregate({
@@ -62,6 +93,7 @@ export const experimentsRouter = createTRPCRouter({
data: {
sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id,
},
});
@@ -117,9 +149,10 @@ export const experimentsRouter = createTRPCRouter({
return exp;
}),
update: publicProcedure
update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.id, ctx);
return await prisma.experiment.update({
where: {
id: input.id,
@@ -130,11 +163,15 @@ export const experimentsRouter = createTRPCRouter({
});
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.experiment.delete({
where: {
id: input.id,
},
});
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.id, ctx);
await prisma.experiment.delete({
where: {
id: input.id,
},
});
}),
});

View File

@@ -1,6 +1,6 @@
import { isObject } from "lodash-es";
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel, type SupportedModel } from "~/server/types";
@@ -11,129 +11,140 @@ import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: { sortIndex: "asc" },
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
return await prisma.promptVariant.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: { sortIndex: "asc" },
});
}),
if (!variant) {
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
stats: publicProcedure
.input(z.object({ variantId: z.string() }))
.query(async ({ input, ctx }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const outputEvals = await prisma.outputEvaluation.groupBy({
by: ["evaluationId"],
_sum: {
result: true,
},
_count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
if (!variant) {
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
}
await requireCanViewExperiment(variant.experimentId, ctx);
const outputEvals = await prisma.outputEvaluation.groupBy({
by: ["evaluationId"],
_sum: {
result: true,
},
_count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find(
(outputEval) => outputEval.evaluationId === evalItem.id,
);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
},
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
_sum: {
promptTokens: true,
completionTokens: true,
},
},
});
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
scenarioVariantCell: {
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
const overallCost = overallPromptCost + overallCompletionCost;
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: {
visible: true,
testScenario: { visible: true },
// Check if is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
},
_sum: {
promptTokens: true,
completionTokens: true,
},
});
}));
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const overallPromptCost = calculateTokenCost(variant.model, promptTokens);
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true);
return {
evalResults,
promptTokens,
completionTokens,
overallCost,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
const overallCost = overallPromptCost + overallCompletionCost;
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
// Check if is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
}));
return {
evalResults,
promptTokens,
completionTokens,
overallCost,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
create: publicProcedure
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
@@ -141,7 +152,9 @@ export const promptVariantsRouter = createTRPCRouter({
newModel: z.string().optional(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
let originalVariant: PromptVariant | null = null;
if (input.variantId) {
originalVariant = await prisma.promptVariant.findUnique({
@@ -217,7 +230,7 @@ export const promptVariantsRouter = createTRPCRouter({
return newVariant;
}),
update: publicProcedure
update: protectedProcedure
.input(
z.object({
id: z.string(),
@@ -226,7 +239,7 @@ export const promptVariantsRouter = createTRPCRouter({
}),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUnique({
where: {
id: input.id,
@@ -237,6 +250,8 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
}
await requireCanModifyExperiment(existing.experimentId, ctx);
const updatePromptVariantAction = prisma.promptVariant.update({
where: {
id: input.id,
@@ -252,13 +267,18 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
hide: publicProcedure
hide: protectedProcedure
.input(
z.object({
id: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const updatedPromptVariant = await prisma.promptVariant.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -267,19 +287,20 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant;
}),
replaceVariant: publicProcedure
replaceVariant: protectedProcedure
.input(
z.object({
id: z.string(),
constructFn: z.string(),
}),
)
.mutation(async ({ input }) => {
const existing = await prisma.promptVariant.findUnique({
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUniqueOrThrow({
where: {
id: input.id,
},
});
await requireCanModifyExperiment(existing.experimentId, ctx);
if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
@@ -347,14 +368,19 @@ export const promptVariantsRouter = createTRPCRouter({
return { status: "ok" } as const;
}),
reorder: publicProcedure
reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { id: input.draggedId },
});
await requireCanModifyExperiment(experimentId, ctx);
await reorderPromptVariants(input.draggedId, input.droppedId);
}),
});

View File

@@ -1,8 +1,9 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVariantCellsRouter = createTRPCRouter({
get: publicProcedure
@@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
variantId: z.string(),
}),
)
.query(async ({ input }) => {
.query(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanViewExperiment(experimentId, ctx);
return await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {
@@ -35,14 +41,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
},
});
}),
forceRefetch: publicProcedure
forceRefetch: protectedProcedure
.input(
z.object({
scenarioId: z.string(),
variantId: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanModifyExperiment(experimentId, ctx);
const cell = await prisma.scenarioVariantCell.findUnique({
where: {
promptVariantId_testScenarioId: {

View File

@@ -1,32 +1,39 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
create: publicProcedure
return await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
create: protectedProcedure
.input(
z.object({
experimentId: z.string(),
autogenerate: z.boolean().optional(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
const maxSortIndex =
(
await prisma.testScenario.aggregate({
@@ -66,7 +73,14 @@ export const scenariosRouter = createTRPCRouter({
}
}),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
const experimentId = (
await prisma.testScenario.findUniqueOrThrow({
where: { id: input.id },
})
).experimentId;
await requireCanModifyExperiment(experimentId, ctx);
const hiddenScenario = await prisma.testScenario.update({
where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -78,14 +92,14 @@ export const scenariosRouter = createTRPCRouter({
return hiddenScenario;
}),
reorder: publicProcedure
reorder: protectedProcedure
.input(
z.object({
draggedId: z.string(),
droppedId: z.string(),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const dragged = await prisma.testScenario.findUnique({
where: {
id: input.draggedId,
@@ -104,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
);
}
await requireCanModifyExperiment(dragged.experimentId, ctx);
const visibleItems = await prisma.testScenario.findMany({
where: {
experimentId: dragged.experimentId,
@@ -147,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
);
}),
replaceWithValues: publicProcedure
replaceWithValues: protectedProcedure
.input(
z.object({
id: z.string(),
values: z.record(z.string()),
}),
)
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
const existing = await prisma.testScenario.findUnique({
where: {
id: input.id,
@@ -165,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
throw new Error(`Scenario with id ${input.id} does not exist`);
}
await requireCanModifyExperiment(existing.experimentId, ctx);
const newScenario = await prisma.testScenario.create({
data: {
experimentId: existing.experimentId,

View File

@@ -1,11 +1,14 @@
import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const templateVarsRouter = createTRPCRouter({
create: publicProcedure
create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() }))
.mutation(async ({ input }) => {
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.templateVariable.create({
data: {
experimentId: input.experimentId,
@@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({
});
}),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
await requireCanModifyExperiment(experimentId, ctx);
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
});

View File

@@ -27,6 +27,9 @@ type CreateContextOptions = {
session: Session | null;
};
// eslint-disable-next-line @typescript-eslint/no-empty-function
const noOp = () => {};
/**
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
* it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
session: opts.session,
prisma,
markAccessControlRun: noOp,
};
};
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
* errors on the backend.
*/
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
const t = initTRPC.context<typeof createTRPCContext>().create({
transformer: superjson,
errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
export const publicProcedure = t.procedure;
/** Reusable middleware that enforces users are logged in before running the procedure. */
const enforceUserIsAuthed = t.middleware(({ ctx, next }) => {
const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
if (!ctx.session || !ctx.session.user) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return next({
let accessControlRun = false;
const resp = await next({
ctx: {
// infers the `session` as non-nullable
session: { ...ctx.session, user: ctx.session.user },
markAccessControlRun: () => {
accessControlRun = true;
},
},
});
if (!accessControlRun)
throw new TRPCError({
code: "INTERNAL_SERVER_ERROR",
message:
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
});
return resp;
});
/**

View File

@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
import { type GetServerSidePropsContext } from "next";
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
import { prisma } from "~/server/db";
import GitHubProvider from "next-auth/providers/github";
import { env } from "~/env.mjs";
/**
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
},
adapter: PrismaAdapter(prisma),
providers: [
// DiscordProvider({
// clientId: env.DISCORD_CLIENT_ID,
// clientSecret: env.DISCORD_CLIENT_SECRET,
// }),
/**
* ...add more providers here.
*
* Most other providers require a bit more work than the Discord provider. For example, the
* GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
* model. Refer to the NextAuth.js docs for the provider you want to use. Example:
*
* @see https://next-auth.js.org/providers/github
*/
GitHubProvider({
clientId: env.GITHUB_CLIENT_ID,
clientSecret: env.GITHUB_CLIENT_SECRET,
}),
],
theme: {
logo: "/logo.svg",
brandColor: "#ff5733",
},
};
/**

View File

@@ -0,0 +1,19 @@
import { prisma } from "~/server/db";
export default async function userOrg(userId: string) {
return await prisma.organization.upsert({
where: {
personalOrgUserId: userId,
},
update: {},
create: {
personalOrgUserId: userId,
OrganizationUser: {
create: {
userId: userId,
role: "ADMIN",
},
},
},
});
}

View File

@@ -0,0 +1,49 @@
import { OrganizationUserRole } from "@prisma/client";
import { TRPCError } from "@trpc/server";
import { type TRPCContext } from "~/server/api/trpc";
import { prisma } from "~/server/db";
// No-op method for protected routes that really should be accessible to anyone.
export const requireNothing = (ctx: TRPCContext) => {
ctx.markAccessControlRun();
};
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
await prisma.experiment.findFirst({
where: { id: experimentId },
});
// Right now all experiments are publicly viewable, so this is a no-op.
ctx.markAccessControlRun();
};
export const canModifyExperiment = async (experimentId: string, userId: string) => {
const experiment = await prisma.experiment.findFirst({
where: {
id: experimentId,
organization: {
OrganizationUser: {
some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
userId,
},
},
},
},
});
return !!experiment;
};
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!(await canModifyExperiment(experimentId, userId))) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};

View File

@@ -12,6 +12,10 @@ export const useExperiment = () => {
return experiment;
};
export const useExperimentAccess = () => {
return useExperiment().data?.access ?? { canView: false, canModify: false };
};
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
export function useHandledAsyncCallback<T extends unknown[], U>(