Compare commits

..

13 Commits

Author SHA1 Message Date
Kyle Corbitt
60765e51ac Remove model from promptVariant and add cost
Storing the model on promptVariant is problematic because it isn't always in sync with the actual prompt definition. I'm removing it for now to see if we can get away with that -- might have to add it back in later if this causes trouble.

Added `cost` to modelOutput as well so we can cache that, which is important given that the cost calculations won't be the same between different API providers.
2023-07-19 16:20:53 -07:00
arcticfly
4c97b9f147 Refine prompt (#63)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier

* Remove migration script

* Add refine modal

* Fix prettier

* Fix diff checker overflow

* Decrease diff height
2023-07-19 15:31:40 -07:00
arcticfly
58892d8b63 Remove unused fields, refine model translation (#62)
* Remove unused ScenarioVariantCell fields

* Refine deriveNewConstructFn

* Fix prettier
2023-07-19 13:59:11 -07:00
Kyle Corbitt
4fa2dffbcb styling tweaks for SelectModelModal 2023-07-19 07:17:56 -07:00
Kyle Corbitt
654f8c7cf2 Merge pull request #61 from OpenPipe/experiment-page
More visual tweaks
2023-07-19 06:56:58 -07:00
Kyle Corbitt
d02482468d more visual tweaks 2023-07-19 06:54:07 -07:00
Kyle Corbitt
5c6ed22f1d Merge pull request #60 from OpenPipe/experiment-page
experiment page visual tweaks
2023-07-18 22:26:05 -07:00
Kyle Corbitt
2cb623f332 experiment page visual tweaks 2023-07-18 22:22:58 -07:00
Kyle Corbitt
1c1cefe286 Merge pull request #59 from OpenPipe/auth
User accounts
2023-07-18 21:21:46 -07:00
Kyle Corbitt
b4aa95edca sidebar mobile styles 2023-07-18 21:19:06 -07:00
Kyle Corbitt
1dcdba04a6 User accounts
Allows for the creation of user accounts. A few notes on the specifics:

 - Experiments are the main access control objects. If you can view an experiment, you can view all its prompts/scenarios/evals. If you can edit it, you can edit or delete all of those as well.
 - Experiments are owned by Organizations in the database. Organizations can have multiple members and members can have roles of ADMIN, MEMBER or VIEWER.
 - Organizations can either be "personal" or general. Each user has a "personal" organization created as soon as they try to create an experiment. There's currently no UI support for creating general orgs or adding users to them; they're just in the database to future-proof all the ACL logic.
 - You can require that a user is signed-in to see a route using the `protectedProcedure` helper. When you use `protectedProcedure`, you also have to call `ctx.markAccessControlRun()` (or delegate to a function that does it for you; see accessControl.ts). This is to remind us to actually check for access control when we define a new endpoint.
2023-07-18 21:19:03 -07:00
arcticfly
e0e64c4207 Allow user to create a version of their current prompt with a new model (#58)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier

* Use env variable to restrict prisma logs

* Fix env.mjs

* Remove unnecessary scroll bar from function call output

* Properly record when 404 error occurs in queryLLM task

* Add SelectedModelInfo in SelectModelModal

* Add react-select

* Calculate new prompt after switching model

* Send newly selected model with creation request

* Get new prompt construction function back from GPT-4

* Fix prettier

* Fix prettier
2023-07-18 18:24:04 -07:00
arcticfly
fa5b1ab1c5 Allow user to duplicate prompt (#57)
* Add dropdown header for model switching

* Allow variant duplication

* Fix prettier
2023-07-18 13:49:33 -07:00
64 changed files with 2504 additions and 2112 deletions

View File

@@ -18,3 +18,11 @@ DATABASE_URL="postgresql://postgres:postgres@localhost:5432/openpipe?schema=publ
OPENAI_API_KEY="" OPENAI_API_KEY=""
NEXT_PUBLIC_SOCKET_URL="http://localhost:3318" NEXT_PUBLIC_SOCKET_URL="http://localhost:3318"
# Next Auth
NEXTAUTH_SECRET="your_secret"
NEXTAUTH_URL="http://localhost:3000"
# Next Auth Github Provider
GITHUB_CLIENT_ID="your_client_id"
GITHUB_CLIENT_SECRET="your_secret"

View File

@@ -11,6 +11,7 @@ declare module "nextjs-routes" {
} from "next"; } from "next";
export type Route = export type Route =
| StaticRoute<"/account/signin">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }> | DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }> | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
| DynamicRoute<"/experiments/[id]", { "id": string }> | DynamicRoute<"/experiments/[id]", { "id": string }>

View File

@@ -4,11 +4,18 @@
OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data. OpenPipe is a flexible playground for comparing and optimizing LLM prompts. It lets you quickly generate, test and compare candidate prompts with realistic sample data.
**Live Demo:** https://openpipe.ai ## Sample Experiments
These are simple experiments users have created that show how OpenPipe works.
- [Country Capitals](https://openpipe.ai/experiments/11111111-1111-1111-1111-111111111111)
- [Reddit User Needs](https://openpipe.ai/experiments/22222222-2222-2222-2222-222222222222)
- [OpenAI Function Calls](https://openpipe.ai/experiments/2ebbdcb3-ed51-456e-87dc-91f72eaf3e2b)
- [Activity Classification](https://openpipe.ai/experiments/3950940f-ab6b-4b74-841d-7e9dbc4e4ff8)
<img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px"> <img src="https://github.com/openpipe/openpipe/assets/176426/fc7624c6-5b65-4d4d-82b7-4a816f3e5678" alt="demo" height="400px">
Currently there's a public playground available at [https://openpipe.ai/](https://openpipe.ai/), but the recommended approach is to [run locally](#running-locally). You can use our hosted version of OpenPipe at [https://openpipe.ai]. You can also clone this repository and [run it locally](#running-locally).
## High-Level Features ## High-Level Features
@@ -47,5 +54,6 @@ OpenPipe currently supports GPT-3.5 and GPT-4. Wider model support is planned.
5. Install the dependencies: `cd openpipe && pnpm install` 5. Install the dependencies: `cd openpipe && pnpm install`
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`. 6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database. 7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database.
8. Start the app: `pnpm dev`. 8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
9. Navigate to [http://localhost:3000](http://localhost:3000) 9. Start the app: `pnpm dev`.
10. Navigate to [http://localhost:3000](http://localhost:3000)

View File

@@ -17,7 +17,8 @@
"lint": "next lint", "lint": "next lint",
"start": "next start", "start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts", "codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts" "seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'"
}, },
"dependencies": { "dependencies": {
"@babel/preset-typescript": "^7.22.5", "@babel/preset-typescript": "^7.22.5",
@@ -27,6 +28,7 @@
"@emotion/react": "^11.11.1", "@emotion/react": "^11.11.1",
"@emotion/server": "^11.11.0", "@emotion/server": "^11.11.0",
"@emotion/styled": "^11.11.0", "@emotion/styled": "^11.11.0",
"@fontsource/inconsolata": "^5.0.5",
"@monaco-editor/loader": "^1.3.3", "@monaco-editor/loader": "^1.3.3",
"@next-auth/prisma-adapter": "^1.0.5", "@next-auth/prisma-adapter": "^1.0.5",
"@prisma/client": "^4.14.0", "@prisma/client": "^4.14.0",
@@ -58,9 +60,12 @@
"pluralize": "^8.0.0", "pluralize": "^8.0.0",
"posthog-js": "^1.68.4", "posthog-js": "^1.68.4",
"prettier": "^3.0.0", "prettier": "^3.0.0",
"prismjs": "^1.29.0",
"react": "18.2.0", "react": "18.2.0",
"react-diff-viewer": "^3.1.1",
"react-dom": "18.2.0", "react-dom": "18.2.0",
"react-icons": "^4.10.1", "react-icons": "^4.10.1",
"react-select": "^5.7.4",
"react-syntax-highlighter": "^15.5.0", "react-syntax-highlighter": "^15.5.0",
"react-textarea-autosize": "^8.5.0", "react-textarea-autosize": "^8.5.0",
"socket.io": "^4.7.1", "socket.io": "^4.7.1",
@@ -81,6 +86,7 @@
"@types/lodash-es": "^4.17.8", "@types/lodash-es": "^4.17.8",
"@types/node": "^18.16.0", "@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30", "@types/pluralize": "^0.0.30",
"@types/prismjs": "^1.26.0",
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4", "@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7", "@types/react-syntax-highlighter": "^15.5.7",

501
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,124 @@
DROP TABLE "Account";
DROP TABLE "Session";
DROP TABLE "User";
DROP TABLE "VerificationToken";
CREATE TYPE "OrganizationUserRole" AS ENUM ('ADMIN', 'MEMBER', 'VIEWER');
-- CreateTable
CREATE TABLE "Organization" (
"id" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"personalOrgUserId" UUID,
CONSTRAINT "Organization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OrganizationUser" (
"id" UUID NOT NULL,
"role" "OrganizationUserRole" NOT NULL,
"organizationId" UUID NOT NULL,
"userId" UUID NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "OrganizationUser_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Account" (
"id" UUID NOT NULL,
"userId" UUID NOT NULL,
"type" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"providerAccountId" TEXT NOT NULL,
"refresh_token" TEXT,
"refresh_token_expires_in" INTEGER,
"access_token" TEXT,
"expires_at" INTEGER,
"token_type" TEXT,
"scope" TEXT,
"id_token" TEXT,
"session_state" TEXT,
CONSTRAINT "Account_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Session" (
"id" UUID NOT NULL,
"sessionToken" TEXT NOT NULL,
"userId" UUID NOT NULL,
"expires" TIMESTAMP(3) NOT NULL,
CONSTRAINT "Session_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "User" (
"id" UUID NOT NULL,
"name" TEXT,
"email" TEXT,
"emailVerified" TIMESTAMP(3),
"image" TEXT,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "VerificationToken" (
"identifier" TEXT NOT NULL,
"token" TEXT NOT NULL,
"expires" TIMESTAMP(3) NOT NULL
);
INSERT INTO "Organization" ("id", "updatedAt") VALUES ('11111111-1111-1111-1111-111111111111', CURRENT_TIMESTAMP);
-- AlterTable add organizationId as NULLABLE
ALTER TABLE "Experiment" ADD COLUMN "organizationId" UUID;
-- Set default organization for existing experiments
UPDATE "Experiment" SET "organizationId" = '11111111-1111-1111-1111-111111111111';
-- AlterTable set organizationId as NOT NULL
ALTER TABLE "Experiment" ALTER COLUMN "organizationId" SET NOT NULL;
-- CreateIndex
CREATE UNIQUE INDEX "OrganizationUser_organizationId_userId_key" ON "OrganizationUser"("organizationId", "userId");
-- CreateIndex
CREATE UNIQUE INDEX "Account_provider_providerAccountId_key" ON "Account"("provider", "providerAccountId");
-- CreateIndex
CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken");
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token");
-- CreateIndex
CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token");
-- AddForeignKey
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_organizationId_fkey" FOREIGN KEY ("organizationId") REFERENCES "Organization"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OrganizationUser" ADD CONSTRAINT "OrganizationUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Account" ADD CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "Session" ADD CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
CREATE UNIQUE INDEX "Organization_personalOrgUserId_key" ON "Organization"("personalOrgUserId");
ALTER TABLE "Organization" ADD CONSTRAINT "Organization_personalOrgUserId_fkey" FOREIGN KEY ("personalOrgUserId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,16 @@
/*
Warnings:
- You are about to drop the column `completionTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `inputHash` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `output` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `promptTokens` on the `ScenarioVariantCell` table. All the data in the column will be lost.
- You are about to drop the column `timeToComplete` on the `ScenarioVariantCell` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ScenarioVariantCell" DROP COLUMN "completionTokens",
DROP COLUMN "inputHash",
DROP COLUMN "output",
DROP COLUMN "promptTokens",
DROP COLUMN "timeToComplete";

View File

@@ -0,0 +1,8 @@
/*
Warnings:
- You are about to drop the column `model` on the `PromptVariant` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "ModelOutput" ADD COLUMN "cost" DOUBLE PRECISION;

View File

@@ -16,8 +16,12 @@ model Experiment {
sortIndex Int @default(0) sortIndex Int @default(0)
createdAt DateTime @default(now()) organizationId String @db.Uuid
updatedAt DateTime @updatedAt organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
TemplateVariable TemplateVariable[] TemplateVariable TemplateVariable[]
PromptVariant PromptVariant[] PromptVariant PromptVariant[]
TestScenario TestScenario[] TestScenario TestScenario[]
@@ -84,18 +88,13 @@ enum CellRetrievalStatus {
model ScenarioVariantCell { model ScenarioVariantCell {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
inputHash String? // TODO: Remove once migration is complete
output Json? // TODO: Remove once migration is complete
statusCode Int? statusCode Int?
errorMessage String? errorMessage String?
timeToComplete Int? @default(0) // TODO: Remove once migration is complete
retryTime DateTime? retryTime DateTime?
streamingChannel String? streamingChannel String?
retrievalStatus CellRetrievalStatus @default(COMPLETE) retrievalStatus CellRetrievalStatus @default(COMPLETE)
promptTokens Int? // TODO: Remove once migration is complete modelOutput ModelOutput?
completionTokens Int? // TODO: Remove once migration is complete
modelOutput ModelOutput?
promptVariantId String @db.Uuid promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
@@ -115,6 +114,7 @@ model ModelOutput {
inputHash String inputHash String
output Json output Json
timeToComplete Int @default(0) timeToComplete Int @default(0)
cost Float?
promptTokens Int? promptTokens Int?
completionTokens Int? completionTokens Int?
@@ -169,41 +169,77 @@ model OutputEvaluation {
@@unique([modelOutputId, evaluationId]) @@unique([modelOutputId, evaluationId])
} }
// Necessary for Next auth model Organization {
id String @id @default(uuid()) @db.Uuid
personalOrgUserId String? @unique @db.Uuid
PersonalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
OrganizationUser OrganizationUser[]
Experiment Experiment[]
}
enum OrganizationUserRole {
ADMIN
MEMBER
VIEWER
}
model OrganizationUser {
id String @id @default(uuid()) @db.Uuid
role OrganizationUserRole
organizationId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade)
userId String @db.Uuid
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([organizationId, userId])
}
model Account { model Account {
id String @id @default(cuid()) id String @id @default(uuid()) @db.Uuid
userId String userId String @db.Uuid
type String type String
provider String provider String
providerAccountId String providerAccountId String
refresh_token String? // @db.Text refresh_token String? @db.Text
access_token String? // @db.Text refresh_token_expires_in Int?
expires_at Int? access_token String? @db.Text
token_type String? expires_at Int?
scope String? token_type String?
id_token String? // @db.Text scope String?
session_state String? id_token String? @db.Text
user User @relation(fields: [userId], references: [id], onDelete: Cascade) session_state String?
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([provider, providerAccountId]) @@unique([provider, providerAccountId])
} }
model Session { model Session {
id String @id @default(cuid()) id String @id @default(uuid()) @db.Uuid
sessionToken String @unique sessionToken String @unique
userId String userId String @db.Uuid
expires DateTime expires DateTime
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)
} }
model User { model User {
id String @id @default(cuid()) id String @id @default(uuid()) @db.Uuid
name String? name String?
email String? @unique email String? @unique
emailVerified DateTime? emailVerified DateTime?
image String? image String?
accounts Account[] accounts Account[]
sessions Session[] sessions Session[]
OrganizationUser OrganizationUser[]
Organization Organization[]
} }
model VerificationToken { model VerificationToken {

View File

@@ -2,40 +2,47 @@ import { prisma } from "~/server/db";
import dedent from "dedent"; import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111"; const defaultId = "11111111-1111-1111-1111-111111111111";
await prisma.organization.deleteMany({
where: { id: defaultId },
});
await prisma.organization.create({
data: { id: defaultId },
});
// Delete the existing experiment
await prisma.experiment.deleteMany({ await prisma.experiment.deleteMany({
where: { where: {
id: experimentId, id: defaultId,
}, },
}); });
await prisma.experiment.create({ await prisma.experiment.create({
data: { data: {
id: experimentId, id: defaultId,
label: "Country Capitals Example", label: "Country Capitals Example",
organizationId: defaultId,
}, },
}); });
await prisma.scenarioVariantCell.deleteMany({ await prisma.scenarioVariantCell.deleteMany({
where: { where: {
promptVariant: { promptVariant: {
experimentId, experimentId: defaultId,
}, },
}, },
}); });
await prisma.promptVariant.deleteMany({ await prisma.promptVariant.deleteMany({
where: { where: {
experimentId, experimentId: defaultId,
}, },
}); });
await prisma.promptVariant.createMany({ await prisma.promptVariant.createMany({
data: [ data: [
{ {
experimentId, experimentId: defaultId,
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
@@ -52,7 +59,7 @@ await prisma.promptVariant.createMany({
}`, }`,
}, },
{ {
experimentId, experimentId: defaultId,
label: "Prompt Variant 2", label: "Prompt Variant 2",
sortIndex: 1, sortIndex: 1,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
@@ -73,14 +80,14 @@ await prisma.promptVariant.createMany({
await prisma.templateVariable.deleteMany({ await prisma.templateVariable.deleteMany({
where: { where: {
experimentId, experimentId: defaultId,
}, },
}); });
await prisma.templateVariable.createMany({ await prisma.templateVariable.createMany({
data: [ data: [
{ {
experimentId, experimentId: defaultId,
label: "country", label: "country",
}, },
], ],
@@ -88,28 +95,28 @@ await prisma.templateVariable.createMany({
await prisma.testScenario.deleteMany({ await prisma.testScenario.deleteMany({
where: { where: {
experimentId, experimentId: defaultId,
}, },
}); });
await prisma.testScenario.createMany({ await prisma.testScenario.createMany({
data: [ data: [
{ {
experimentId, experimentId: defaultId,
sortIndex: 0, sortIndex: 0,
variableValues: { variableValues: {
country: "Spain", country: "Spain",
}, },
}, },
{ {
experimentId, experimentId: defaultId,
sortIndex: 1, sortIndex: 1,
variableValues: { variableValues: {
country: "USA", country: "USA",
}, },
}, },
{ {
experimentId, experimentId: defaultId,
sortIndex: 2, sortIndex: 2,
variableValues: { variableValues: {
country: "Chile", country: "Chile",
@@ -120,13 +127,13 @@ await prisma.testScenario.createMany({
const variants = await prisma.promptVariant.findMany({ const variants = await prisma.promptVariant.findMany({
where: { where: {
experimentId, experimentId: defaultId,
}, },
}); });
const scenarios = await prisma.testScenario.findMany({ const scenarios = await prisma.testScenario.findMany({
where: { where: {
experimentId, experimentId: defaultId,
}, },
}); });

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@ services:
dockerContext: . dockerContext: .
plan: standard plan: standard
domains: domains:
- openpipe.ai - app.openpipe.ai
envVars: envVars:
- key: NODE_ENV - key: NODE_ENV
value: production value: production

View File

@@ -1,7 +1,7 @@
import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react"; import { Button, type ButtonProps, HStack, Spinner, Icon } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs"; import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
// Extracted Button styling into reusable component // Extracted Button styling into reusable component
const StyledButton = ({ children, onClick }: ButtonProps) => ( const StyledButton = ({ children, onClick }: ButtonProps) => (
@@ -17,6 +17,8 @@ const StyledButton = ({ children, onClick }: ButtonProps) => (
); );
export default function NewScenarioButton() { export default function NewScenarioButton() {
const { canModify } = useExperimentAccess();
const experiment = useExperiment(); const experiment = useExperiment();
const mutation = api.scenarios.create.useMutation(); const mutation = api.scenarios.create.useMutation();
const utils = api.useContext(); const utils = api.useContext();
@@ -38,6 +40,8 @@ export default function NewScenarioButton() {
await utils.scenarios.list.invalidate(); await utils.scenarios.list.invalidate();
}, [mutation]); }, [mutation]);
if (!canModify) return null;
return ( return (
<HStack spacing={2}> <HStack spacing={2}>
<StyledButton onClick={onClick}> <StyledButton onClick={onClick}>

View File

@@ -1,7 +1,7 @@
import { Button, Icon, Spinner } from "@chakra-ui/react"; import { Box, Button, Icon, Spinner, Text } from "@chakra-ui/react";
import { BsPlus } from "react-icons/bs"; import { BsPlus } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { cellPadding, headerMinHeight } from "../constants"; import { cellPadding, headerMinHeight } from "../constants";
export default function NewVariantButton() { export default function NewVariantButton() {
@@ -17,6 +17,9 @@ export default function NewVariantButton() {
await utils.promptVariants.list.invalidate(); await utils.promptVariants.list.invalidate();
}, [mutation]); }, [mutation]);
const { canModify } = useExperimentAccess();
if (!canModify) return <Box w={cellPadding.x} />;
return ( return (
<Button <Button
w="100%" w="100%"
@@ -31,7 +34,7 @@ export default function NewVariantButton() {
minH={headerMinHeight} minH={headerMinHeight}
> >
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} /> <Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
Add Variant <Text display={{ base: "none", md: "flex" }}>Add Variant</Text>
</Button> </Button>
); );
} }

View File

@@ -1,5 +1,6 @@
import { Button, HStack, Icon } from "@chakra-ui/react"; import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react";
import { BsArrowClockwise } from "react-icons/bs"; import { BsArrowClockwise } from "react-icons/bs";
import { useExperimentAccess } from "~/utils/hooks";
export const CellOptions = ({ export const CellOptions = ({
refetchingOutput, refetchingOutput,
@@ -8,25 +9,28 @@ export const CellOptions = ({
refetchingOutput: boolean; refetchingOutput: boolean;
refetchOutput: () => void; refetchOutput: () => void;
}) => { }) => {
const { canModify } = useExperimentAccess();
return ( return (
<HStack justifyContent="flex-end" w="full"> <HStack justifyContent="flex-end" w="full">
{!refetchingOutput && ( {!refetchingOutput && canModify && (
<Button <Tooltip label="Refetch output" aria-label="refetch output">
size="xs" <Button
w={4} size="xs"
h={4} w={4}
py={4} h={4}
px={4} py={4}
minW={0} px={4}
borderRadius={8} minW={0}
color="gray.500" borderRadius={8}
variant="ghost" color="gray.500"
cursor="pointer" variant="ghost"
onClick={refetchOutput} cursor="pointer"
aria-label="refetch output" onClick={refetchOutput}
> aria-label="refetch output"
<Icon as={BsArrowClockwise} boxSize={4} /> >
</Button> <Icon as={BsArrowClockwise} boxSize={4} />
</Button>
</Tooltip>
)} )}
</HStack> </HStack>
); );

View File

@@ -106,7 +106,7 @@ export default function OutputCell({
h="100%" h="100%"
fontSize="xs" fontSize="xs"
flexWrap="wrap" flexWrap="wrap"
overflowX="auto" overflowX="hidden"
justifyContent="space-between" justifyContent="space-between"
> >
<VStack w="full" flex={1} spacing={0}> <VStack w="full" flex={1} spacing={0}>
@@ -129,7 +129,7 @@ export default function OutputCell({
)} )}
</SyntaxHighlighter> </SyntaxHighlighter>
</VStack> </VStack>
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} /> <OutputStats modelOutput={modelOutput} scenario={scenario} />
</VStack> </VStack>
); );
} }
@@ -143,9 +143,7 @@ export default function OutputCell({
<CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} /> <CellOptions refetchingOutput={refetchingOutput} refetchOutput={hardRefetch} />
<Text>{contentToDisplay}</Text> <Text>{contentToDisplay}</Text>
</VStack> </VStack>
{modelOutput && ( {modelOutput && <OutputStats modelOutput={modelOutput} scenario={scenario} />}
<OutputStats model={variant.model} modelOutput={modelOutput} scenario={scenario} />
)}
</VStack> </VStack>
); );
} }

View File

@@ -1,19 +1,14 @@
import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types"; import { type Scenario } from "../types";
import { type RouterOutputs } from "~/utils/api"; import { type RouterOutputs } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react"; import { HStack, Icon, Text, Tooltip } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip"; import { CostTooltip } from "~/components/tooltip/CostTooltip";
const SHOW_COST = true;
const SHOW_TIME = true; const SHOW_TIME = true;
export const OutputStats = ({ export const OutputStats = ({
model,
modelOutput, modelOutput,
}: { }: {
model: SupportedModel | string | null;
modelOutput: NonNullable< modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"] NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
>; >;
@@ -24,12 +19,6 @@ export const OutputStats = ({
const promptTokens = modelOutput.promptTokens; const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens; const completionTokens = modelOutput.completionTokens;
const promptCost = promptTokens && model ? calculateTokenCost(model, promptTokens) : 0;
const completionCost =
completionTokens && model ? calculateTokenCost(model, completionTokens, true) : 0;
const cost = promptCost + completionCost;
return ( return (
<HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack w="full" align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>
@@ -53,11 +42,15 @@ export const OutputStats = ({
); );
})} })}
</HStack> </HStack>
{SHOW_COST && ( {modelOutput.cost && (
<CostTooltip promptTokens={promptTokens} completionTokens={completionTokens} cost={cost}> <CostTooltip
promptTokens={promptTokens}
completionTokens={completionTokens}
cost={modelOutput.cost}
>
<HStack spacing={0}> <HStack spacing={0}>
<Icon as={BsCurrencyDollar} /> <Icon as={BsCurrencyDollar} />
<Text mr={1}>{cost.toFixed(3)}</Text> <Text mr={1}>{modelOutput.cost.toFixed(3)}</Text>
</HStack> </HStack>
</CostTooltip> </CostTooltip>
)} )}

View File

@@ -2,7 +2,7 @@ import { type DragEvent } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { isEqual } from "lodash-es"; import { isEqual } from "lodash-es";
import { type Scenario } from "./types"; import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react"; import { useState } from "react";
import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react"; import { Box, Button, Flex, HStack, Icon, Spinner, Stack, Tooltip, VStack } from "@chakra-ui/react";
@@ -19,6 +19,8 @@ export default function ScenarioEditor({
hovered: boolean; hovered: boolean;
canHide: boolean; canHide: boolean;
}) { }) {
const { canModify } = useExperimentAccess();
const savedValues = scenario.variableValues as Record<string, string>; const savedValues = scenario.variableValues as Record<string, string>;
const utils = api.useContext(); const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false); const [isDragTarget, setIsDragTarget] = useState(false);
@@ -74,6 +76,7 @@ export default function ScenarioEditor({
alignItems="flex-start" alignItems="flex-start"
pr={cellPadding.x} pr={cellPadding.x}
py={cellPadding.y} py={cellPadding.y}
pl={canModify ? 0 : cellPadding.x}
height="100%" height="100%"
draggable={!variableInputHovered} draggable={!variableInputHovered}
onDragStart={(e) => { onDragStart={(e) => {
@@ -93,35 +96,38 @@ export default function ScenarioEditor({
onDrop={onReorder} onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"} backgroundColor={isDragTarget ? "gray.100" : "transparent"}
> >
<Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}> {canModify && (
{props.canHide && ( <Stack alignSelf="flex-start" opacity={props.hovered ? 1 : 0} spacing={0}>
<> {props.canHide && (
<Tooltip label="Hide scenario" hasArrow> <>
{/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */} <Tooltip label="Hide scenario" hasArrow>
<Button {/* for some reason the tooltip can't position itself properly relative to the icon without the wrapping box */}
variant="unstyled" <Button
variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset"
onClick={onHide}
_hover={{
color: "gray.800",
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400" color="gray.400"
height="unset" _hover={{ color: "gray.800", cursor: "pointer" }}
width="unset" />
minW="unset" </>
onClick={onHide} )}
_hover={{ </Stack>
color: "gray.800", )}
cursor: "pointer",
}}
>
<Icon as={hidingInProgress ? Spinner : BsX} boxSize={6} />
</Button>
</Tooltip>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
</>
)}
</Stack>
{variableLabels.length === 0 ? ( {variableLabels.length === 0 ? (
<Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box> <Box color="gray.500">{vars.data ? "No scenario variables configured" : "Loading..."}</Box>
) : ( ) : (
@@ -155,6 +161,8 @@ export default function ScenarioEditor({
fontSize="sm" fontSize="sm"
lineHeight={1.2} lineHeight={1.2}
value={value} value={value}
isDisabled={!canModify}
_disabled={{ opacity: 1, cursor: "default" }}
onChange={(e) => { onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value })); setValues((prev) => ({ ...prev, [key]: e.target.value }));
}} }}

View File

@@ -1,6 +1,6 @@
import { Button, GridItem, HStack, Heading } from "@chakra-ui/react"; import { Button, GridItem, HStack, Heading } from "@chakra-ui/react";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import { useElementDimensions } from "~/utils/hooks"; import { useElementDimensions, useExperimentAccess } from "~/utils/hooks";
import { stickyHeaderStyle } from "./styles"; import { stickyHeaderStyle } from "./styles";
import { BsPencil } from "react-icons/bs"; import { BsPencil } from "react-icons/bs";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
@@ -13,6 +13,7 @@ export const ScenariosHeader = ({
numScenarios: number; numScenarios: number;
}) => { }) => {
const openDrawer = useAppStore((s) => s.openDrawer); const openDrawer = useAppStore((s) => s.openDrawer);
const { canModify } = useExperimentAccess();
const [ref, dimensions] = useElementDimensions(); const [ref, dimensions] = useElementDimensions();
const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px"; const topValue = dimensions ? `-${dimensions.height - 24}px` : "-455px";
@@ -33,16 +34,18 @@ export const ScenariosHeader = ({
<Heading size="xs" fontWeight="bold" flex={1}> <Heading size="xs" fontWeight="bold" flex={1}>
Scenarios ({numScenarios}) Scenarios ({numScenarios})
</Heading> </Heading>
<Button {canModify && (
size="xs" <Button
variant="ghost" size="xs"
color="gray.500" variant="ghost"
aria-label="Edit" color="gray.500"
leftIcon={<BsPencil />} aria-label="Edit"
onClick={openDrawer} leftIcon={<BsPencil />}
> onClick={openDrawer}
Edit Vars >
</Button> Edit Vars
</Button>
)}
</HStack> </HStack>
</GridItem> </GridItem>
); );

View File

@@ -1,11 +1,12 @@
import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react"; import { Box, Button, HStack, Spinner, Tooltip, useToast, Text } from "@chakra-ui/react";
import { useRef, useEffect, useState, useCallback } from "react"; import { useRef, useEffect, useState, useCallback } from "react";
import { useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks"; import { useExperimentAccess, useHandledAsyncCallback, useModifierKeyLabel } from "~/utils/hooks";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
export default function VariantEditor(props: { variant: PromptVariant }) { export default function VariantEditor(props: { variant: PromptVariant }) {
const { canModify } = useExperimentAccess();
const monaco = useAppStore.use.sharedVariantEditor.monaco(); const monaco = useAppStore.use.sharedVariantEditor.monaco();
const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null); const editorRef = useRef<ReturnType<NonNullable<typeof monaco>["editor"]["create"]> | null>(null);
const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`); const [editorId] = useState(() => `editor_${Math.random().toString(36).substring(7)}`);
@@ -21,7 +22,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn); setIsChanged(currentFn.length > 0 && currentFn !== lastSavedFn);
}, [lastSavedFn]); }, [lastSavedFn]);
useEffect(checkForChanges, [checkForChanges, lastSavedFn]); const matchUpdatedSavedFn = useCallback(() => {
if (!editorRef.current) return;
editorRef.current.setValue(lastSavedFn);
setIsChanged(false);
}, [lastSavedFn]);
useEffect(matchUpdatedSavedFn, [matchUpdatedSavedFn, lastSavedFn]);
const replaceVariant = api.promptVariants.replaceVariant.useMutation(); const replaceVariant = api.promptVariants.replaceVariant.useMutation();
const utils = api.useContext(); const utils = api.useContext();
@@ -40,18 +47,6 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
const model = editorRef.current.getModel(); const model = editorRef.current.getModel();
if (!model) return; if (!model) return;
const markers = monaco?.editor.getModelMarkers({ resource: model.uri });
const hasErrors = markers?.some((m) => m.severity === monaco?.MarkerSeverity.Error);
if (hasErrors) {
toast({
title: "Invalid TypeScript",
description: "Please fix the TypeScript errors before saving.",
status: "error",
});
return;
}
// Make sure the user defined the prompt with the string "prompt\w*=" somewhere // Make sure the user defined the prompt with the string "prompt\w*=" somewhere
const promptRegex = /prompt\s*=/; const promptRegex = /prompt\s*=/;
if (!promptRegex.test(currentFn)) { if (!promptRegex.test(currentFn)) {
@@ -103,6 +98,7 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
wordWrapBreakAfterCharacters: "", wordWrapBreakAfterCharacters: "",
wordWrapBreakBeforeCharacters: "", wordWrapBreakBeforeCharacters: "",
quickSuggestions: true, quickSuggestions: true,
readOnly: !canModify,
}); });
editorRef.current.onDidFocusEditorText(() => { editorRef.current.onDidFocusEditorText(() => {
@@ -130,6 +126,13 @@ export default function VariantEditor(props: { variant: PromptVariant }) {
/* eslint-disable-next-line react-hooks/exhaustive-deps */ /* eslint-disable-next-line react-hooks/exhaustive-deps */
}, [monaco, editorId]); }, [monaco, editorId]);
useEffect(() => {
if (!editorRef.current) return;
editorRef.current.updateOptions({
readOnly: !canModify,
});
}, [canModify]);
return ( return (
<Box w="100%" pos="relative"> <Box w="100%" pos="relative">
<div id={editorId} style={{ height: "400px", width: "100%" }}></div> <div id={editorId} style={{ height: "400px", width: "100%" }}></div>

View File

@@ -1,107 +0,0 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "./types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { Button, HStack, Icon, Tooltip } from "@chakra-ui/react"; // Changed here
import { BsX } from "react-icons/bs";
import { RiDraggable } from "react-icons/ri";
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: props.variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, props.variant.id]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
);
return (
<HStack
spacing={4}
alignItems="center"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea // Changed to Input
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
{props.canHide && (
<Tooltip label="Remove Variant" hasArrow>
<Button variant="ghost" colorScheme="gray" size="sm" onClick={onHide}>
<Icon as={BsX} boxSize={6} />
</Button>
</Tooltip>
)}
</HStack>
);
}

View File

@@ -1,4 +1,4 @@
import { HStack, Icon, Skeleton, Text, useToken } from "@chakra-ui/react"; import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types"; import { type PromptVariant } from "./types";
import { cellPadding } from "../constants"; import { cellPadding } from "../constants";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
@@ -69,7 +69,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
); );
})} })}
</HStack> </HStack>
{data.overallCost && !data.awaitingRetrievals ? ( {data.overallCost && !data.awaitingRetrievals && (
<CostTooltip <CostTooltip
promptTokens={data.promptTokens} promptTokens={data.promptTokens}
completionTokens={data.completionTokens} completionTokens={data.completionTokens}
@@ -80,8 +80,6 @@ export default function VariantStats(props: { variant: PromptVariant }) {
<Text mr={1}>{data.overallCost.toFixed(3)}</Text> <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack> </HStack>
</CostTooltip> </CostTooltip>
) : (
<Skeleton height={4} width={12} mr={1} />
)} )}
</HStack> </HStack>
); );

View File

@@ -4,7 +4,7 @@ import NewScenarioButton from "./NewScenarioButton";
import NewVariantButton from "./NewVariantButton"; import NewVariantButton from "./NewVariantButton";
import ScenarioRow from "./ScenarioRow"; import ScenarioRow from "./ScenarioRow";
import VariantEditor from "./VariantEditor"; import VariantEditor from "./VariantEditor";
import VariantHeader from "./VariantHeader"; import VariantHeader from "../VariantHeader/VariantHeader";
import VariantStats from "./VariantStats"; import VariantStats from "./VariantStats";
import { ScenariosHeader } from "./ScenariosHeader"; import { ScenariosHeader } from "./ScenariosHeader";
import { stickyHeaderStyle } from "./styles"; import { stickyHeaderStyle } from "./styles";
@@ -43,9 +43,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
<ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} /> <ScenariosHeader headerRows={headerRows} numScenarios={scenarios.data.length} />
{variants.data.map((variant) => ( {variants.data.map((variant) => (
<GridItem key={variant.uiId} padding={0} sx={stickyHeaderStyle} borderTopWidth={1}> <VariantHeader key={variant.uiId} variant={variant} canHide={variants.data.length > 1} />
<VariantHeader variant={variant} canHide={variants.data.length > 1} />
</GridItem>
))} ))}
<GridItem <GridItem
rowSpan={scenarios.data.length + headerRows} rowSpan={scenarios.data.length + headerRows}

View File

@@ -2,7 +2,7 @@ import { type SystemStyleObject } from "@chakra-ui/react";
export const stickyHeaderStyle: SystemStyleObject = { export const stickyHeaderStyle: SystemStyleObject = {
position: "sticky", position: "sticky",
top: "-1px", top: "0",
backgroundColor: "#fff", backgroundColor: "#fff",
zIndex: 1, zIndex: 1,
}; };

View File

@@ -1,21 +0,0 @@
import { Flex, Icon, Link, Text } from "@chakra-ui/react";
import { BsExclamationTriangleFill } from "react-icons/bs";
import { env } from "~/env.mjs";
export default function PublicPlaygroundWarning() {
if (!env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND) return null;
return (
<Flex bgColor="red.600" color="whiteAlpha.900" p={2} align="center">
<Icon boxSize={4} mr={2} as={BsExclamationTriangleFill} />
<Text>
Warning: this is a public playground. Anyone can see, edit or delete your experiments. For
private use,{" "}
<Link textDecor="underline" href="https://github.com/openpipe/openpipe" target="_blank">
run a local copy
</Link>
.
</Text>
</Flex>
);
}

View File

@@ -0,0 +1,45 @@
import { HStack, VStack } from "@chakra-ui/react";
import React from "react";
import DiffViewer, { DiffMethod } from "react-diff-viewer";
import Prism from "prismjs";
import "prismjs/components/prism-javascript";
import "prismjs/themes/prism.css"; // choose a theme you like
const CompareFunctions = ({
originalFunction,
newFunction = "",
}: {
originalFunction: string;
newFunction?: string;
}) => {
console.log("newFunction", newFunction);
const highlightSyntax = (str: string) => {
let highlighted;
try {
highlighted = Prism.highlight(str, Prism.languages.javascript as Prism.Grammar, "javascript");
} catch (e) {
console.error("Error highlighting:", e);
highlighted = str;
}
return <pre style={{ display: "inline" }} dangerouslySetInnerHTML={{ __html: highlighted }} />;
};
return (
<HStack w="full" spacing={5}>
<VStack w="full" spacing={4} maxH="65vh" fontSize={12} lineHeight={1} overflowY="auto">
<DiffViewer
oldValue={originalFunction}
newValue={newFunction || originalFunction}
splitView={true}
hideLineNumbers={true}
leftTitle="Original"
rightTitle={newFunction ? "Modified" : "Unmodified"}
disableWordDiff={true}
compareMethod={DiffMethod.CHARS}
renderContent={highlightSyntax}
/>
</VStack>
</HStack>
);
};
export default CompareFunctions;

View File

@@ -0,0 +1,103 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
HStack,
} from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { type PromptVariant } from "@prisma/client";
import { useState } from "react";
import AutoResizeTextArea from "../AutoResizeTextArea";
import CompareFunctions from "./CompareFunctions";
export const RefinePromptModal = ({
variant,
onClose,
}: {
variant: PromptVariant;
onClose: () => void;
}) => {
const utils = api.useContext();
const { mutateAsync: getRefinedPromptMutateAsync, data: refinedPromptFn } =
api.promptVariants.getRefinedPromptFn.useMutation();
const [instructions, setInstructions] = useState<string>("");
const [getRefinedPromptFn, refiningInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId) return;
await getRefinedPromptMutateAsync({
id: variant.id,
instructions,
});
}, [getRefinedPromptMutateAsync, onClose, variant, instructions]);
const replaceVariantMutation = api.promptVariants.replaceVariant.useMutation();
const [replaceVariant, replacementInProgress] = useHandledAsyncCallback(async () => {
if (!variant.experimentId || !refinedPromptFn) return;
await replaceVariantMutation.mutateAsync({
id: variant.id,
constructFn: refinedPromptFn,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [replaceVariantMutation, variant, onClose, refinedPromptFn]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "7xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Refine Your Prompt</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<HStack w="full">
<AutoResizeTextArea
value={instructions}
onChange={(e) => setInstructions(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.metaKey && !e.ctrlKey && !e.shiftKey) {
e.preventDefault();
e.currentTarget.blur();
getRefinedPromptFn();
}
}}
placeholder="Use chain of thought"
/>
<Button onClick={getRefinedPromptFn}>
{refiningInProgress ? <Spinner boxSize={4} /> : <Text>Submit</Text>}
</Button>
</HStack>
<CompareFunctions
originalFunction={variant.constructFn}
newFunction={refinedPromptFn}
/>
</VStack>
</ModalBody>
<ModalFooter>
<HStack spacing={4}>
<Button onClick={onClose}>Cancel</Button>
<Button
colorScheme="blue"
onClick={replaceVariant}
minW={24}
disabled={!refinedPromptFn}
>
{replacementInProgress ? <Spinner boxSize={4} /> : <Text>Accept</Text>}
</Button>
</HStack>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,89 @@
import {
VStack,
Text,
HStack,
type StackProps,
GridItem,
SimpleGrid,
Link,
} from "@chakra-ui/react";
import { modelStats } from "~/server/modelStats";
import { type SupportedModel } from "~/server/types";
export const ModelStatsCard = ({ label, model }: { label: string; model: SupportedModel }) => {
const stats = modelStats[model];
return (
<VStack w="full" align="start">
<Text fontWeight="bold" fontSize="sm" textTransform="uppercase">
{label}
</Text>
<VStack w="full" spacing={6} bgColor="gray.100" p={4} borderRadius={4}>
<HStack w="full" align="flex-start">
<Text flex={1} fontSize="lg">
<Text as="span" color="gray.600">
{stats.provider} /{" "}
</Text>
<Text as="span" fontWeight="bold" color="gray.900">
{model}
</Text>
</Text>
<Link
href={stats.learnMoreUrl}
isExternal
color="blue.500"
fontWeight="bold"
fontSize="sm"
ml={2}
>
Learn More
</Link>
</HStack>
<SimpleGrid
w="full"
justifyContent="space-between"
alignItems="flex-start"
fontSize="sm"
columns={{ base: 2, md: 4 }}
>
<SelectedModelLabeledInfo label="Context" info={stats.contextLength} />
<SelectedModelLabeledInfo
label="Input"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo
label="Output"
info={
<Text>
${(stats.promptTokenPrice * 1000).toFixed(3)}
<Text color="gray.500"> / 1K tokens</Text>
</Text>
}
/>
<SelectedModelLabeledInfo label="Speed" info={<Text>{stats.speed}</Text>} />
</SimpleGrid>
</VStack>
</VStack>
);
};
const SelectedModelLabeledInfo = ({
label,
info,
...props
}: {
label: string;
info: string | number | React.ReactElement;
} & StackProps) => (
<GridItem>
<VStack alignItems="flex-start" {...props}>
<Text fontWeight="bold">{label}</Text>
<Text>{info}</Text>
</VStack>
</GridItem>
);

View File

@@ -0,0 +1,77 @@
import {
Button,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
VStack,
Text,
Spinner,
} from "@chakra-ui/react";
import { useState } from "react";
import { type SupportedModel } from "~/server/types";
import { ModelStatsCard } from "./ModelStatsCard";
import { SelectModelSearch } from "./SelectModelSearch";
import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
export const SelectModelModal = ({
originalModel,
variantId,
onClose,
}: {
originalModel: SupportedModel;
variantId: string;
onClose: () => void;
}) => {
const [selectedModel, setSelectedModel] = useState<SupportedModel>(originalModel);
const utils = api.useContext();
const experiment = useExperiment();
const createMutation = api.promptVariants.create.useMutation();
const [createNewVariant, creationInProgress] = useHandledAsyncCallback(async () => {
if (!experiment?.data?.id) return;
await createMutation.mutateAsync({
experimentId: experiment?.data?.id,
variantId,
newModel: selectedModel,
});
await utils.promptVariants.list.invalidate();
onClose();
}, [createMutation, experiment?.data?.id, variantId, onClose]);
return (
<Modal isOpen onClose={onClose} size={{ base: "xl", sm: "2xl", md: "3xl" }}>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader>Select a New Model</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>
<ModelStatsCard label="Original Model" model={originalModel} />
{originalModel !== selectedModel && (
<ModelStatsCard label="New Model" model={selectedModel} />
)}
<SelectModelSearch selectedModel={selectedModel} setSelectedModel={setSelectedModel} />
</VStack>
</ModalBody>
<ModalFooter>
<Button
colorScheme="blue"
onClick={createNewVariant}
minW={24}
disabled={originalModel === selectedModel}
>
{creationInProgress ? <Spinner boxSize={4} /> : <Text>Continue</Text>}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
};

View File

@@ -0,0 +1,47 @@
import { VStack, Text } from "@chakra-ui/react";
import { type LegacyRef, useCallback } from "react";
import Select, { type SingleValue } from "react-select";
import { type SupportedModel } from "~/server/types";
import { useElementDimensions } from "~/utils/hooks";
const modelOptions: { value: SupportedModel; label: string }[] = [
{ value: "gpt-3.5-turbo", label: "gpt-3.5-turbo" },
{ value: "gpt-3.5-turbo-0613", label: "gpt-3.5-turbo-0613" },
{ value: "gpt-3.5-turbo-16k", label: "gpt-3.5-turbo-16k" },
{ value: "gpt-3.5-turbo-16k-0613", label: "gpt-3.5-turbo-16k-0613" },
{ value: "gpt-4", label: "gpt-4" },
{ value: "gpt-4-0613", label: "gpt-4-0613" },
{ value: "gpt-4-32k", label: "gpt-4-32k" },
{ value: "gpt-4-32k-0613", label: "gpt-4-32k-0613" },
];
export const SelectModelSearch = ({
selectedModel,
setSelectedModel,
}: {
selectedModel: SupportedModel;
setSelectedModel: (model: SupportedModel) => void;
}) => {
const handleSelection = useCallback(
(option: SingleValue<{ value: SupportedModel; label: string }>) => {
if (!option) return;
setSelectedModel(option.value);
},
[setSelectedModel],
);
const selectedOption = modelOptions.find((option) => option.value === selectedModel);
const [containerRef, containerDimensions] = useElementDimensions();
return (
<VStack ref={containerRef as LegacyRef<HTMLDivElement>} w="full">
<Text>Browse Models</Text>
<Select
styles={{ control: (provided) => ({ ...provided, width: containerDimensions?.width }) }}
value={selectedOption}
options={modelOptions}
onChange={handleSelection}
/>
</VStack>
);
};

View File

@@ -0,0 +1,123 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import { stickyHeaderStyle } from "../OutputsTable/styles";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(props: { variant: PromptVariant; canHide: boolean }) {
const { canModify } = useExperimentAccess();
const utils = api.useContext();
const [isDragTarget, setIsDragTarget] = useState(false);
const [isInputHovered, setIsInputHovered] = useState(false);
const [label, setLabel] = useState(props.variant.label);
const updateMutation = api.promptVariants.update.useMutation();
const [onSaveLabel] = useHandledAsyncCallback(async () => {
if (label && label !== props.variant.label) {
await updateMutation.mutateAsync({
id: props.variant.id,
updates: { label: label },
});
}
}, [updateMutation, props.variant.id, props.variant.label, label]);
const reorderMutation = api.promptVariants.reorder.useMutation();
const [onReorder] = useHandledAsyncCallback(
async (e: DragEvent<HTMLDivElement>) => {
e.preventDefault();
setIsDragTarget(false);
const draggedId = e.dataTransfer.getData("text/plain");
const droppedId = props.variant.id;
if (!draggedId || !droppedId || draggedId === droppedId) return;
await reorderMutation.mutateAsync({
draggedId,
droppedId,
});
await utils.promptVariants.list.invalidate();
},
[reorderMutation, props.variant.id],
);
const [menuOpen, setMenuOpen] = useState(false);
if (!canModify) {
return (
<GridItem padding={0} sx={stickyHeaderStyle} borderTopWidth={1}>
<Text fontSize={16} fontWeight="bold" px={cellPadding.x} py={cellPadding.y}>
{props.variant.label}
</Text>
</GridItem>
);
}
return (
<GridItem
padding={0}
sx={{
...stickyHeaderStyle,
// Ensure that the menu always appears above the sticky header of other variants
zIndex: menuOpen ? "dropdown" : stickyHeaderStyle.zIndex,
}}
borderTopWidth={1}
>
<HStack
spacing={4}
alignItems="flex-start"
minH={headerMinHeight}
draggable={!isInputHovered}
onDragStart={(e) => {
e.dataTransfer.setData("text/plain", props.variant.id);
e.currentTarget.style.opacity = "0.4";
}}
onDragEnd={(e) => {
e.currentTarget.style.opacity = "1";
}}
onDragOver={(e) => {
e.preventDefault();
setIsDragTarget(true);
}}
onDragLeave={() => {
setIsDragTarget(false);
}}
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
<Icon
as={RiDraggable}
boxSize={6}
mt={2}
color="gray.400"
_hover={{ color: "gray.800", cursor: "pointer" }}
/>
<AutoResizeTextArea
size="sm"
value={label}
onChange={(e) => setLabel(e.target.value)}
onBlur={onSaveLabel}
placeholder="Variant Name"
borderWidth={1}
borderColor="transparent"
fontWeight="bold"
fontSize={16}
_hover={{ borderColor: "gray.300" }}
_focus={{ borderColor: "blue.500", outline: "none" }}
flex={1}
px={cellPadding.x}
onMouseEnter={() => setIsInputHovered(true)}
onMouseLeave={() => setIsInputHovered(false)}
/>
<VariantHeaderMenuButton
variant={props.variant}
canHide={props.canHide}
menuOpen={menuOpen}
setMenuOpen={setMenuOpen}
/>
</HStack>
</GridItem>
);
}

View File

@@ -0,0 +1,114 @@
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import {
Button,
Icon,
Menu,
MenuButton,
MenuItem,
MenuList,
MenuDivider,
Text,
Spinner,
} from "@chakra-ui/react";
import { BsFillTrashFill, BsGear } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { AiOutlineDiff } from "react-icons/ai";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { SelectModelModal } from "../SelectModelModal/SelectModelModal";
import { type SupportedModel } from "~/server/types";
export default function VariantHeaderMenuButton({
variant,
canHide,
menuOpen,
setMenuOpen,
}: {
variant: PromptVariant;
canHide: boolean;
menuOpen: boolean;
setMenuOpen: (open: boolean) => void;
}) {
const utils = api.useContext();
const duplicateMutation = api.promptVariants.create.useMutation();
const [duplicateVariant, duplicationInProgress] = useHandledAsyncCallback(async () => {
await duplicateMutation.mutateAsync({
experimentId: variant.experimentId,
variantId: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [duplicateMutation, variant.experimentId, variant.id]);
const hideMutation = api.promptVariants.hide.useMutation();
const [onHide] = useHandledAsyncCallback(async () => {
await hideMutation.mutateAsync({
id: variant.id,
});
await utils.promptVariants.list.invalidate();
}, [hideMutation, variant.id]);
const [selectModelModalOpen, setSelectModelModalOpen] = useState(false);
const [refinePromptModalOpen, setRefinePromptModalOpen] = useState(false);
return (
<>
<Menu isOpen={menuOpen} onOpen={() => setMenuOpen(true)} onClose={() => setMenuOpen(false)}>
{duplicationInProgress ? (
<Spinner boxSize={4} mx={3} my={3} />
) : (
<MenuButton>
<Button variant="ghost">
<Icon as={BsGear} />
</Button>
</MenuButton>
)}
<MenuList mt={-3} fontSize="md">
<MenuItem icon={<Icon as={FaRegClone} boxSize={4} w={5} />} onClick={duplicateVariant}>
Duplicate
</MenuItem>
<MenuItem
icon={<Icon as={RiExchangeFundsFill} boxSize={5} />}
onClick={() => setSelectModelModalOpen(true)}
>
Change Model
</MenuItem>
<MenuItem
icon={<Icon as={AiOutlineDiff} boxSize={5} />}
onClick={() => setRefinePromptModalOpen(true)}
>
Refine
</MenuItem>
{canHide && (
<>
<MenuDivider />
<MenuItem
onClick={onHide}
icon={<Icon as={BsFillTrashFill} boxSize={5} />}
color="red.600"
_hover={{ backgroundColor: "red.50" }}
>
<Text>Hide</Text>
</MenuItem>
</>
)}
</MenuList>
</Menu>
{selectModelModalOpen && (
<SelectModelModal
originalModel={variant.model as SupportedModel}
variantId={variant.id}
onClose={() => setSelectModelModalOpen(false)}
/>
)}
{refinePromptModalOpen && (
<RefinePromptModal variant={variant} onClose={() => setRefinePromptModalOpen(false)} />
)}
</>
);
}

View File

@@ -1,17 +1,11 @@
import { import { HStack, Icon, VStack, Text, Divider, Spinner, AspectRatio } from "@chakra-ui/react";
Card,
CardBody,
HStack,
Icon,
VStack,
Text,
CardHeader,
Divider,
Box,
} from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import { formatTimePast } from "~/utils/dayjs"; import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link";
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
type ExperimentData = { type ExperimentData = {
testScenarioCount: number; testScenarioCount: number;
@@ -24,47 +18,42 @@ type ExperimentData = {
}; };
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => { export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
const router = useRouter();
return ( return (
<Box <AspectRatio ratio={1.2} w="full">
as={Card} <VStack
variant="elevated" as={Link}
bg="gray.50" href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
_hover={{ bg: "gray.100" }} bg="gray.50"
transition="background 0.2s" _hover={{ bg: "gray.100" }}
cursor="pointer" transition="background 0.2s"
onClick={(e) => { cursor="pointer"
e.preventDefault(); borderColor="gray.200"
void router.push({ pathname: "/experiments/[id]", query: { id: exp.id } }, undefined, { borderWidth={1}
shallow: true, p={4}
}); justify="space-between"
}} >
> <HStack w="full" color="gray.700" justify="center">
<CardHeader>
<HStack w="full" color="gray.700">
<Icon as={RiFlaskLine} boxSize={4} /> <Icon as={RiFlaskLine} boxSize={4} />
<Text fontWeight="bold">{exp.label}</Text> <Text fontWeight="bold">{exp.label}</Text>
</HStack> </HStack>
</CardHeader> <HStack h="full" spacing={4} flex={1} align="center">
<CardBody>
<HStack w="full" mb={8} spacing={4}>
<CountLabel label="Variants" count={exp.promptVariantCount} /> <CountLabel label="Variants" count={exp.promptVariantCount} />
<Divider h={12} orientation="vertical" /> <Divider h={12} orientation="vertical" />
<CountLabel label="Scenarios" count={exp.testScenarioCount} /> <CountLabel label="Scenarios" count={exp.testScenarioCount} />
</HStack> </HStack>
<HStack w="full" color="gray.500" fontSize="xs"> <HStack w="full" color="gray.500" fontSize="xs" textAlign="center">
<Text>Created {formatTimePast(exp.createdAt)}</Text> <Text flex={1}>Created {formatTimePast(exp.createdAt)}</Text>
<Divider h={4} orientation="vertical" /> <Divider h={4} orientation="vertical" />
<Text>Updated {formatTimePast(exp.updatedAt)}</Text> <Text flex={1}>Updated {formatTimePast(exp.updatedAt)}</Text>
</HStack> </HStack>
</CardBody> </VStack>
</Box> </AspectRatio>
); );
}; };
const CountLabel = ({ label, count }: { label: string; count: number }) => { const CountLabel = ({ label, count }: { label: string; count: number }) => {
return ( return (
<VStack alignItems="flex-start"> <VStack alignItems="center" flex={1}>
<Text color="gray.500" fontWeight="bold"> <Text color="gray.500" fontWeight="bold">
{label} {label}
</Text> </Text>
@@ -74,3 +63,33 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
</VStack> </VStack>
); );
}; };
export const NewExperimentCard = () => {
const router = useRouter();
const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
}, [createMutation, router]);
return (
<AspectRatio ratio={1.2} w="full">
<VStack
align="center"
justify="center"
_hover={{ cursor: "pointer", bg: "gray.50" }}
transition="background 0.2s"
cursor="pointer"
borderColor="gray.200"
borderWidth={1}
p={4}
onClick={createExperiment}
>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={8} />
<Text display={{ base: "none", md: "block" }} ml={2}>
New Experiment
</Text>
</VStack>
</AspectRatio>
);
};

View File

@@ -1,31 +0,0 @@
import { Icon, Button, Spinner, Text, type ButtonProps } from "@chakra-ui/react";
import { api } from "~/utils/api";
import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { useHandledAsyncCallback } from "~/utils/hooks";
export const NewExperimentButton = (props: ButtonProps) => {
const router = useRouter();
const utils = api.useContext();
const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" });
await utils.experiments.list.invalidate();
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } });
}, [createMutation, router]);
return (
<Button
onClick={createExperiment}
display="flex"
alignItems="center"
variant={{ base: "solid", md: "ghost" }}
{...props}
>
<Icon as={isLoading ? Spinner : BsPlusSquare} boxSize={4} />
<Text display={{ base: "none", md: "block" }} ml={2}>
New Experiment
</Text>
</Button>
);
};

View File

@@ -1,84 +1,100 @@
import { useState, useEffect } from "react";
import { import {
Heading, Heading,
VStack, VStack,
Icon, Icon,
HStack, HStack,
Image, Image,
Grid,
GridItem,
Divider,
Text, Text,
Box, Box,
type BoxProps, type BoxProps,
type LinkProps, type LinkProps,
Link, Link,
Flex,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import Head from "next/head"; import Head from "next/head";
import { BsGithub, BsTwitter } from "react-icons/bs"; import { BsGithub, BsPersonCircle } from "react-icons/bs";
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import PublicPlaygroundWarning from "../PublicPlaygroundWarning";
import { type IconType } from "react-icons"; import { type IconType } from "react-icons";
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import { useState, useEffect } from "react"; import { signIn, useSession } from "next-auth/react";
import UserMenu from "./UserMenu";
type IconLinkProps = BoxProps & LinkProps & { label: string; icon: IconType; href: string }; type IconLinkProps = BoxProps & LinkProps & { label?: string; icon: IconType };
const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => { const IconLink = ({ icon, label, href, target, color, ...props }: IconLinkProps) => {
const isActive = useRouter().pathname.startsWith(href); const router = useRouter();
const isActive = href && router.pathname.startsWith(href);
return ( return (
<Box <HStack
w="full"
p={4}
color={color}
as={Link} as={Link}
href={href} href={href}
target={target} target={target}
w="full" bgColor={isActive ? "gray.200" : "transparent"}
bgColor={isActive ? "gray.300" : "transparent"} _hover={{ bgColor: "gray.200", textDecoration: "none" }}
_hover={{ bgColor: "gray.300" }}
py={4}
justifyContent="start" justifyContent="start"
cursor="pointer" cursor="pointer"
{...props} {...props}
> >
<HStack w="full" px={4} color={color}> <Icon as={icon} boxSize={6} mr={2} />
<Icon as={icon} boxSize={6} mr={2} /> <Text fontWeight="bold" fontSize="sm">
<Text fontWeight="bold">{label}</Text> {label}
</HStack> </Text>
</Box> </HStack>
); );
}; };
const Divider = () => <Box h="1px" bgColor="gray.200" />;
const NavSidebar = () => { const NavSidebar = () => {
const user = useSession().data;
return ( return (
<VStack align="stretch" bgColor="gray.100" py={2} pb={0} height="100%"> <VStack
<Link href="/" w="full" _hover={{ textDecoration: "none" }}> align="stretch"
<HStack spacing={0} pl="3"> bgColor="gray.100"
<Image src="/logo.svg" alt="" w={8} h={8} /> py={2}
<Heading size="md" p={2} pl={{ base: 16, md: 2 }}> pb={0}
OpenPipe height="100%"
</Heading> w={{ base: "56px", md: "200px" }}
</HStack> overflow="hidden"
</Link> >
<Divider /> <HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={4} py={2}>
<Image src="/logo.svg" alt="" boxSize={6} mr={4} />
<Heading size="md" fontFamily="inconsolata, monospace">
OpenPipe
</Heading>
</HStack>
<VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}> <VStack spacing={0} align="flex-start" overflowY="auto" overflowX="hidden" flex={1}>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" /> {user != null && (
<>
<IconLink icon={RiFlaskLine} label="Experiments" href="/experiments" />
</>
)}
{user === null && (
<IconLink
icon={BsPersonCircle}
label="Sign In"
onClick={() => {
signIn("github").catch(console.error);
}}
/>
)}
</VStack> </VStack>
<Divider /> {user ? <UserMenu user={user} /> : <Divider />}
<VStack w="full" spacing={0} pb={2}> <VStack spacing={0} align="center">
<IconLink <Link
icon={BsGithub}
label="GitHub"
href="https://github.com/openpipe/openpipe" href="https://github.com/openpipe/openpipe"
target="_blank" target="_blank"
color="gray.500" color="gray.500"
_hover={{ color: "gray.800" }} _hover={{ color: "gray.800" }}
/> p={2}
<IconLink >
icon={BsTwitter} <Icon as={BsGithub} boxSize={6} />
label="Twitter" </Link>
href="https://twitter.com/corbtt"
target="_blank"
color="gray.500"
_hover={{ color: "gray.800" }}
/>
</VStack> </VStack>
</VStack> </VStack>
); );
@@ -105,25 +121,14 @@ export default function AppShell(props: { children: React.ReactNode; title?: str
}, []); }, []);
return ( return (
<Grid <Flex h={vh} w="100vw">
h={vh}
w="100vw"
templateColumns={{ base: "56px minmax(0, 1fr)", md: "200px minmax(0, 1fr)" }}
templateRows="max-content 1fr"
templateAreas={'"warning warning"\n"sidebar main"'}
>
<Head> <Head>
<title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title> <title>{props.title ? `${props.title} | OpenPipe` : "OpenPipe"}</title>
</Head> </Head>
<GridItem area="warning"> <NavSidebar />
<PublicPlaygroundWarning /> <Box h="100%" flex={1} overflowY="auto">
</GridItem>
<GridItem area="sidebar" overflow="hidden">
<NavSidebar />
</GridItem>
<GridItem area="main" overflowY="auto">
{props.children} {props.children}
</GridItem> </Box>
</Grid> </Flex>
); );
} }

View File

@@ -0,0 +1,74 @@
import {
HStack,
Icon,
Image,
VStack,
Text,
Popover,
PopoverTrigger,
PopoverContent,
Link,
} from "@chakra-ui/react";
import { type Session } from "next-auth";
import { signOut } from "next-auth/react";
import { BsBoxArrowRight, BsChevronRight, BsPersonCircle } from "react-icons/bs";
export default function UserMenu({ user }: { user: Session }) {
const profileImage = user.user.image ? (
<Image src={user.user.image} alt="profile picture" boxSize={8} borderRadius="50%" />
) : (
<Icon as={BsPersonCircle} boxSize={6} />
);
return (
<>
<Popover placement="right">
<PopoverTrigger>
<HStack
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
px={3}
spacing={3}
py={2}
borderColor={"gray.200"}
borderTopWidth={1}
borderBottomWidth={1}
cursor="pointer"
_hover={{
bgColor: "gray.200",
}}
>
{profileImage}
<VStack spacing={0} align="start" flex={1} flexShrink={1}>
<Text fontWeight="bold" fontSize="sm">
{user.user.name}
</Text>
<Text color="gray.500" fontSize="xs">
{user.user.email}
</Text>
</VStack>
<Icon as={BsChevronRight} boxSize={4} color="gray.500" />
</HStack>
</PopoverTrigger>
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px">
<VStack align="stretch" spacing={0}>
{/* sign out */}
<HStack
as={Link}
onClick={() => {
signOut().catch(console.error);
}}
px={4}
py={2}
spacing={4}
color="gray.500"
fontSize="sm"
>
<Icon as={BsBoxArrowRight} boxSize={6} />
<Text>Sign out</Text>
</HStack>
</VStack>
</PopoverContent>
</Popover>
</>
);
}

View File

@@ -10,6 +10,13 @@ export const env = createEnv({
DATABASE_URL: z.string().url(), DATABASE_URL: z.string().url(),
NODE_ENV: z.enum(["development", "test", "production"]).default("development"), NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
OPENAI_API_KEY: z.string().min(1), OPENAI_API_KEY: z.string().min(1),
RESTRICT_PRISMA_LOGS: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
GITHUB_CLIENT_ID: z.string().min(1),
GITHUB_CLIENT_SECRET: z.string().min(1),
}, },
/** /**
@@ -19,11 +26,6 @@ export const env = createEnv({
*/ */
client: { client: {
NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(), NEXT_PUBLIC_POSTHOG_KEY: z.string().optional(),
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: z
.string()
.optional()
.default("false")
.transform((val) => val.toLowerCase() === "true"),
NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"), NEXT_PUBLIC_SOCKET_URL: z.string().url().default("http://localhost:3318"),
}, },
@@ -35,9 +37,11 @@ export const env = createEnv({
DATABASE_URL: process.env.DATABASE_URL, DATABASE_URL: process.env.DATABASE_URL,
NODE_ENV: process.env.NODE_ENV, NODE_ENV: process.env.NODE_ENV,
OPENAI_API_KEY: process.env.OPENAI_API_KEY, OPENAI_API_KEY: process.env.OPENAI_API_KEY,
RESTRICT_PRISMA_LOGS: process.env.RESTRICT_PRISMA_LOGS,
NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY, NEXT_PUBLIC_POSTHOG_KEY: process.env.NEXT_PUBLIC_POSTHOG_KEY,
NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND: process.env.NEXT_PUBLIC_IS_PUBLIC_PLAYGROUND,
NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL, NEXT_PUBLIC_SOCKET_URL: process.env.NEXT_PUBLIC_SOCKET_URL,
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
}, },
/** /**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.

View File

@@ -0,0 +1,23 @@
import { signIn, useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useEffect } from "react";
import AppShell from "~/components/nav/AppShell";
export default function SignIn() {
const session = useSession().data;
const router = useRouter();
useEffect(() => {
if (session) {
router.push("/experiments").catch(console.error);
} else if (session === null) {
signIn("github").catch(console.error);
}
}, [session, router]);
return (
<AppShell>
<div />
</AppShell>
);
}

View File

@@ -124,6 +124,8 @@ export default function Experiment() {
); );
} }
const canModify = experiment.data?.access.canModify ?? false;
return ( return (
<AppShell title={experiment.data?.label}> <AppShell title={experiment.data?.label}>
<VStack h="full"> <VStack h="full">
@@ -143,37 +145,45 @@ export default function Experiment() {
</Link> </Link>
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem isCurrentPage> <BreadcrumbItem isCurrentPage>
<Input {canModify ? (
size="sm" <Input
value={label} size="sm"
onChange={(e) => setLabel(e.target.value)} value={label}
onBlur={onSaveLabel} onChange={(e) => setLabel(e.target.value)}
borderWidth={1} onBlur={onSaveLabel}
borderColor="transparent" borderWidth={1}
fontSize={16} borderColor="transparent"
px={0} fontSize={16}
minW={{ base: 100, lg: 300 }} px={0}
flex={1} minW={{ base: 100, lg: 300 }}
_hover={{ borderColor: "gray.300" }} flex={1}
_focus={{ borderColor: "blue.500", outline: "none" }} _hover={{ borderColor: "gray.300" }}
/> _focus={{ borderColor: "blue.500", outline: "none" }}
/>
) : (
<Text fontSize={16} px={0} minW={{ base: 100, lg: 300 }} flex={1}>
{experiment.data?.label}
</Text>
)}
</BreadcrumbItem> </BreadcrumbItem>
</Breadcrumb> </Breadcrumb>
<HStack> {canModify && (
<Button <HStack>
size="sm" <Button
variant={{ base: "outline", lg: "ghost" }} size="sm"
colorScheme="gray" variant={{ base: "outline", lg: "ghost" }}
fontWeight="normal" colorScheme="gray"
onClick={openDrawer} fontWeight="normal"
> onClick={openDrawer}
<Icon as={BsGearFill} boxSize={4} color="gray.600" /> >
<Text display={{ base: "none", lg: "block" }} ml={2}> <Icon as={BsGearFill} boxSize={4} color="gray.600" />
Edit Vars & Evals <Text display={{ base: "none", lg: "block" }} ml={2}>
</Text> Edit Vars & Evals
</Button> </Text>
<DeleteButton /> </Button>
</HStack> <DeleteButton />
</HStack>
)}
</Flex> </Flex>
<SettingsDrawer /> <SettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}> <Box w="100%" overflowX="auto" flex={1}>

View File

@@ -1,25 +1,50 @@
import { import {
SimpleGrid, SimpleGrid,
HStack,
Icon, Icon,
VStack, VStack,
Breadcrumb, Breadcrumb,
BreadcrumbItem, BreadcrumbItem,
Flex, Flex,
Center,
Text,
Link,
HStack,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { RiFlaskLine } from "react-icons/ri"; import { RiFlaskLine } from "react-icons/ri";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { NewExperimentButton } from "~/components/experiments/NewExperimentButton"; import { ExperimentCard, NewExperimentCard } from "~/components/experiments/ExperimentCard";
import { ExperimentCard } from "~/components/experiments/ExperimentCard"; import { signIn, useSession } from "next-auth/react";
export default function ExperimentsPage() { export default function ExperimentsPage() {
const experiments = api.experiments.list.useQuery(); const experiments = api.experiments.list.useQuery();
const user = useSession().data;
if (user === null) {
return (
<AppShell title="Experiments">
<Center h="100%">
<Text>
<Link
onClick={() => {
signIn("github").catch(console.error);
}}
textDecor="underline"
>
Sign in
</Link>{" "}
to view or create new experiments!
</Text>
</Center>
</AppShell>
);
}
return ( return (
<AppShell> <AppShell title="Experiments">
<VStack alignItems={"flex-start"} m={4} mt={1}> <VStack alignItems={"flex-start"} px={4} py={2}>
<HStack w="full" justifyContent="space-between" mb={4}> <HStack minH={8} align="center">
<Breadcrumb flex={1}> <Breadcrumb flex={1}>
<BreadcrumbItem> <BreadcrumbItem>
<Flex alignItems="center"> <Flex alignItems="center">
@@ -27,9 +52,9 @@ export default function ExperimentsPage() {
</Flex> </Flex>
</BreadcrumbItem> </BreadcrumbItem>
</Breadcrumb> </Breadcrumb>
<NewExperimentButton mr={4} borderRadius={8} />
</HStack> </HStack>
<SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4"> <SimpleGrid w="full" columns={{ base: 1, md: 2, lg: 3, xl: 4 }} spacing={8} p="4">
<NewExperimentCard />
{experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)} {experiments?.data?.map((exp) => <ExperimentCard key={exp.id} exp={exp} />)}
</SimpleGrid> </SimpleGrid>
</VStack> </VStack>

View File

@@ -1,20 +1,25 @@
import { EvalType } from "@prisma/client"; import { EvalType } from "@prisma/client";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { runAllEvals } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const evaluationsRouter = createTRPCRouter({ export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure
return await prisma.evaluation.findMany({ .input(z.object({ experimentId: z.string() }))
where: { .query(async ({ input, ctx }) => {
experimentId: input.experimentId, await requireCanViewExperiment(input.experimentId, ctx);
},
orderBy: { createdAt: "asc" },
});
}),
create: publicProcedure return await prisma.evaluation.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: { createdAt: "asc" },
});
}),
create: protectedProcedure
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
@@ -23,7 +28,9 @@ export const evaluationsRouter = createTRPCRouter({
evalType: z.nativeEnum(EvalType), evalType: z.nativeEnum(EvalType),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.evaluation.create({ await prisma.evaluation.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
@@ -38,7 +45,7 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(input.experimentId); await runAllEvals(input.experimentId);
}), }),
update: publicProcedure update: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
@@ -49,7 +56,12 @@ export const evaluationsRouter = createTRPCRouter({
}), }),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const evaluation = await prisma.evaluation.update({ const evaluation = await prisma.evaluation.update({
where: { id: input.id }, where: { id: input.id },
data: { data: {
@@ -69,9 +81,16 @@ export const evaluationsRouter = createTRPCRouter({
await runAllEvals(evaluation.experimentId); await runAllEvals(evaluation.experimentId);
}), }),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { delete: protectedProcedure
await prisma.evaluation.delete({ .input(z.object({ id: z.string() }))
where: { id: input.id }, .mutation(async ({ input, ctx }) => {
}); const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
}), where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
await prisma.evaluation.delete({
where: { id: input.id },
});
}),
}); });

View File

@@ -1,14 +1,31 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import dedent from "dedent"; import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import {
canModifyExperiment,
requireCanModifyExperiment,
requireCanViewExperiment,
requireNothing,
} from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
export const experimentsRouter = createTRPCRouter({ export const experimentsRouter = createTRPCRouter({
list: publicProcedure.query(async () => { list: protectedProcedure.query(async ({ ctx }) => {
// Anyone can list experiments
requireNothing(ctx);
const experiments = await prisma.experiment.findMany({ const experiments = await prisma.experiment.findMany({
where: {
organization: {
OrganizationUser: {
some: { userId: ctx.session.user.id },
},
},
},
orderBy: { orderBy: {
sortIndex: "asc", sortIndex: "desc",
}, },
}); });
@@ -40,15 +57,29 @@ export const experimentsRouter = createTRPCRouter({
return experimentsWithCounts; return experimentsWithCounts;
}), }),
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input }) => { get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
return await prisma.experiment.findFirst({ await requireCanViewExperiment(input.id, ctx);
where: { const experiment = await prisma.experiment.findFirstOrThrow({
id: input.id, where: { id: input.id },
},
}); });
const canModify = ctx.session?.user.id
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
: false;
return {
...experiment,
access: {
canView: true,
canModify,
},
};
}), }),
create: publicProcedure.input(z.object({})).mutation(async () => { create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
// Anyone can create an experiment
requireNothing(ctx);
const maxSortIndex = const maxSortIndex =
( (
await prisma.experiment.aggregate({ await prisma.experiment.aggregate({
@@ -62,6 +93,7 @@ export const experimentsRouter = createTRPCRouter({
data: { data: {
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`, label: `Experiment ${maxSortIndex + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id,
}, },
}); });
@@ -117,9 +149,10 @@ export const experimentsRouter = createTRPCRouter({
return exp; return exp;
}), }),
update: publicProcedure update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) })) .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.id, ctx);
return await prisma.experiment.update({ return await prisma.experiment.update({
where: { where: {
id: input.id, id: input.id,
@@ -130,11 +163,15 @@ export const experimentsRouter = createTRPCRouter({
}); });
}), }),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { delete: protectedProcedure
await prisma.experiment.delete({ .input(z.object({ id: z.string() }))
where: { .mutation(async ({ input, ctx }) => {
id: input.id, await requireCanModifyExperiment(input.id, ctx);
},
}); await prisma.experiment.delete({
}), where: {
id: input.id,
},
});
}),
}); });

View File

@@ -1,152 +1,174 @@
import dedent from "dedent";
import { isObject } from "lodash-es"; import { isObject } from "lodash-es";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import { OpenAIChatModel } from "~/server/types"; import { OpenAIChatModel, type SupportedModel } from "~/server/types";
import { constructPrompt } from "~/server/utils/constructPrompt"; import { constructPrompt } from "~/server/utils/constructPrompt";
import userError from "~/server/utils/error"; import userError from "~/server/utils/error";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client";
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const promptVariantsRouter = createTRPCRouter({ export const promptVariantsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure
return await prisma.promptVariant.findMany({ .input(z.object({ experimentId: z.string() }))
where: { .query(async ({ input, ctx }) => {
experimentId: input.experimentId, await requireCanViewExperiment(input.experimentId, ctx);
visible: true,
},
orderBy: { sortIndex: "asc" },
});
}),
stats: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => { return await prisma.promptVariant.findMany({
const variant = await prisma.promptVariant.findUnique({ where: {
where: { experimentId: input.experimentId,
id: input.variantId, visible: true,
}, },
}); orderBy: { sortIndex: "asc" },
});
}),
if (!variant) { stats: publicProcedure
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); .input(z.object({ variantId: z.string() }))
} .query(async ({ input, ctx }) => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: input.variantId,
},
});
const outputEvals = await prisma.outputEvaluation.groupBy({ if (!variant) {
by: ["evaluationId"], throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
_sum: { }
result: true,
}, await requireCanViewExperiment(variant.experimentId, ctx);
_count: {
id: true, const outputEvals = await prisma.outputEvaluation.groupBy({
}, by: ["evaluationId"],
where: { _sum: {
modelOutput: { result: true,
scenarioVariantCell: { },
promptVariant: { _count: {
id: input.variantId, id: true,
visible: true, },
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
}, },
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find(
(outputEval) => outputEval.evaluationId === evalItem.id,
);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
},
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
scenarioVariantCell: {
promptVariantId: input.variantId,
testScenario: { testScenario: {
visible: true, visible: true,
}, },
}, },
}, },
}, _sum: {
}); cost: true,
promptTokens: true,
const evals = await prisma.evaluation.findMany({ completionTokens: true,
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.scenarioVariantCell.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
modelOutput: {
is: {},
}, },
}, });
});
const overallTokens = await prisma.modelOutput.aggregate({ const promptTokens = overallTokens._sum?.promptTokens ?? 0;
where: { const completionTokens = overallTokens._sum?.completionTokens ?? 0;
scenarioVariantCell: {
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { testScenario: { visible: true },
visible: true, // Check if is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
}, },
}, },
}, }));
_sum: {
promptTokens: true,
completionTokens: true,
},
});
const promptTokens = overallTokens._sum?.promptTokens ?? 0; return {
const overallPromptCost = calculateTokenCost(variant.model, promptTokens); evalResults,
const completionTokens = overallTokens._sum?.completionTokens ?? 0; promptTokens,
const overallCompletionCost = calculateTokenCost(variant.model, completionTokens, true); completionTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
const overallCost = overallPromptCost + overallCompletionCost; create: protectedProcedure
const awaitingRetrievals = !!(await prisma.scenarioVariantCell.findFirst({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
// Check if is PENDING or IN_PROGRESS
retrievalStatus: {
in: ["PENDING", "IN_PROGRESS"],
},
},
}));
return {
evalResults,
promptTokens,
completionTokens,
overallCost,
scenarioCount,
outputCount,
awaitingRetrievals,
};
}),
create: publicProcedure
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
variantId: z.string().optional(),
newModel: z.string().optional(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const lastVariant = await prisma.promptVariant.findFirst({ await requireCanViewExperiment(input.experimentId, ctx);
where: {
experimentId: input.experimentId, let originalVariant: PromptVariant | null = null;
visible: true, if (input.variantId) {
}, originalVariant = await prisma.promptVariant.findUnique({
orderBy: { where: {
sortIndex: "desc", id: input.variantId,
}, },
}); });
} else {
originalVariant = await prisma.promptVariant.findFirst({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "desc",
},
});
}
const largestSortIndex = const largestSortIndex =
( (
@@ -160,24 +182,23 @@ export const promptVariantsRouter = createTRPCRouter({
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
const newVariantLabel =
input.variantId && originalVariant
? `${originalVariant?.label} Copy`
: `Prompt Variant ${largestSortIndex + 2}`;
const newConstructFn = await deriveNewConstructFn(
originalVariant,
input.newModel as SupportedModel,
);
const createNewVariantAction = prisma.promptVariant.create({ const createNewVariantAction = prisma.promptVariant.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
label: `Prompt Variant ${largestSortIndex + 2}`, label: newVariantLabel,
sortIndex: (lastVariant?.sortIndex ?? 0) + 1, sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
constructFn: constructFn: newConstructFn,
lastVariant?.constructFn ?? model: originalVariant?.model ?? "gpt-3.5-turbo",
dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`,
model: lastVariant?.model ?? "gpt-3.5-turbo",
}, },
}); });
@@ -186,6 +207,11 @@ export const promptVariantsRouter = createTRPCRouter({
recordExperimentUpdated(input.experimentId), recordExperimentUpdated(input.experimentId),
]); ]);
if (originalVariant) {
// Insert new variant to right of original variant
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
}
const scenarios = await prisma.testScenario.findMany({ const scenarios = await prisma.testScenario.findMany({
where: { where: {
experimentId: input.experimentId, experimentId: input.experimentId,
@@ -200,7 +226,7 @@ export const promptVariantsRouter = createTRPCRouter({
return newVariant; return newVariant;
}), }),
update: publicProcedure update: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
@@ -209,7 +235,7 @@ export const promptVariantsRouter = createTRPCRouter({
}), }),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUnique({ const existing = await prisma.promptVariant.findUnique({
where: { where: {
id: input.id, id: input.id,
@@ -220,6 +246,8 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.id} does not exist`); throw new Error(`Prompt Variant with id ${input.id} does not exist`);
} }
await requireCanModifyExperiment(existing.experimentId, ctx);
const updatePromptVariantAction = prisma.promptVariant.update({ const updatePromptVariantAction = prisma.promptVariant.update({
where: { where: {
id: input.id, id: input.id,
@@ -235,13 +263,18 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant; return updatedPromptVariant;
}), }),
hide: publicProcedure hide: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
const updatedPromptVariant = await prisma.promptVariant.update({ const updatedPromptVariant = await prisma.promptVariant.update({
where: { id: input.id }, where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -250,19 +283,50 @@ export const promptVariantsRouter = createTRPCRouter({
return updatedPromptVariant; return updatedPromptVariant;
}), }),
replaceVariant: publicProcedure getRefinedPromptFn: protectedProcedure
.input(
z.object({
id: z.string(),
instructions: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUniqueOrThrow({
where: {
id: input.id,
},
});
await requireCanModifyExperiment(existing.experimentId, ctx);
const constructedPrompt = await constructPrompt({ constructFn: existing.constructFn }, null);
const promptConstructionFn = await deriveNewConstructFn(
existing,
// @ts-expect-error TODO clean this up
constructedPrompt?.model as SupportedModel,
input.instructions,
);
// TODO: Validate promptConstructionFn
// TODO: Record in some sort of history
return promptConstructionFn;
}),
replaceVariant: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
constructFn: z.string(), constructFn: z.string(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const existing = await prisma.promptVariant.findUnique({ const existing = await prisma.promptVariant.findUniqueOrThrow({
where: { where: {
id: input.id, id: input.id,
}, },
}); });
await requireCanModifyExperiment(existing.experimentId, ctx);
if (!existing) { if (!existing) {
throw new Error(`Prompt Variant with id ${input.id} does not exist`); throw new Error(`Prompt Variant with id ${input.id} does not exist`);
@@ -330,72 +394,19 @@ export const promptVariantsRouter = createTRPCRouter({
return { status: "ok" } as const; return { status: "ok" } as const;
}), }),
reorder: publicProcedure reorder: protectedProcedure
.input( .input(
z.object({ z.object({
draggedId: z.string(), draggedId: z.string(),
droppedId: z.string(), droppedId: z.string(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const dragged = await prisma.promptVariant.findUnique({ const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
where: { where: { id: input.draggedId },
id: input.draggedId,
},
}); });
await requireCanModifyExperiment(experimentId, ctx);
const dropped = await prisma.promptVariant.findUnique({ await reorderPromptVariants(input.draggedId, input.droppedId);
where: {
id: input.droppedId,
},
});
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
throw new Error(
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: dragged.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the dragged item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
// Find the index of the dragged item and the dropped item
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
// Determine the new index for the dragged item
let newIndex;
if (dragIndex < dropIndex) {
newIndex = dropIndex + 1; // Insert after the dropped item
} else {
newIndex = dropIndex; // Insert before the dropped item
}
// Insert the dragged item at the new position
orderedItems.splice(newIndex, 0, dragged);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
}), }),
}); });

View File

@@ -1,8 +1,9 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask"; import { queueLLMRetrievalTask } from "~/server/utils/queueLLMRetrievalTask";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVariantCellsRouter = createTRPCRouter({ export const scenarioVariantCellsRouter = createTRPCRouter({
get: publicProcedure get: publicProcedure
@@ -12,7 +13,12 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
variantId: z.string(), variantId: z.string(),
}), }),
) )
.query(async ({ input }) => { .query(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanViewExperiment(experimentId, ctx);
return await prisma.scenarioVariantCell.findUnique({ return await prisma.scenarioVariantCell.findUnique({
where: { where: {
promptVariantId_testScenarioId: { promptVariantId_testScenarioId: {
@@ -35,14 +41,20 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}, },
}); });
}), }),
forceRefetch: publicProcedure forceRefetch: protectedProcedure
.input( .input(
z.object({ z.object({
scenarioId: z.string(), scenarioId: z.string(),
variantId: z.string(), variantId: z.string(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
where: { id: input.scenarioId },
});
await requireCanModifyExperiment(experimentId, ctx);
const cell = await prisma.scenarioVariantCell.findUnique({ const cell = await prisma.scenarioVariantCell.findUnique({
where: { where: {
promptVariantId_testScenarioId: { promptVariantId_testScenarioId: {

View File

@@ -1,32 +1,39 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen"; import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { runAllEvals } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenariosRouter = createTRPCRouter({ export const scenariosRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure
return await prisma.testScenario.findMany({ .input(z.object({ experimentId: z.string() }))
where: { .query(async ({ input, ctx }) => {
experimentId: input.experimentId, await requireCanViewExperiment(input.experimentId, ctx);
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
create: publicProcedure return await prisma.testScenario.findMany({
where: {
experimentId: input.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
}),
create: protectedProcedure
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
autogenerate: z.boolean().optional(), autogenerate: z.boolean().optional(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
const maxSortIndex = const maxSortIndex =
( (
await prisma.testScenario.aggregate({ await prisma.testScenario.aggregate({
@@ -66,7 +73,14 @@ export const scenariosRouter = createTRPCRouter({
} }
}), }),
hide: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
const experimentId = (
await prisma.testScenario.findUniqueOrThrow({
where: { id: input.id },
})
).experimentId;
await requireCanModifyExperiment(experimentId, ctx);
const hiddenScenario = await prisma.testScenario.update({ const hiddenScenario = await prisma.testScenario.update({
where: { id: input.id }, where: { id: input.id },
data: { visible: false, experiment: { update: { updatedAt: new Date() } } }, data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
@@ -78,14 +92,14 @@ export const scenariosRouter = createTRPCRouter({
return hiddenScenario; return hiddenScenario;
}), }),
reorder: publicProcedure reorder: protectedProcedure
.input( .input(
z.object({ z.object({
draggedId: z.string(), draggedId: z.string(),
droppedId: z.string(), droppedId: z.string(),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const dragged = await prisma.testScenario.findUnique({ const dragged = await prisma.testScenario.findUnique({
where: { where: {
id: input.draggedId, id: input.draggedId,
@@ -104,6 +118,8 @@ export const scenariosRouter = createTRPCRouter({
); );
} }
await requireCanModifyExperiment(dragged.experimentId, ctx);
const visibleItems = await prisma.testScenario.findMany({ const visibleItems = await prisma.testScenario.findMany({
where: { where: {
experimentId: dragged.experimentId, experimentId: dragged.experimentId,
@@ -147,14 +163,14 @@ export const scenariosRouter = createTRPCRouter({
); );
}), }),
replaceWithValues: publicProcedure replaceWithValues: protectedProcedure
.input( .input(
z.object({ z.object({
id: z.string(), id: z.string(),
values: z.record(z.string()), values: z.record(z.string()),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
const existing = await prisma.testScenario.findUnique({ const existing = await prisma.testScenario.findUnique({
where: { where: {
id: input.id, id: input.id,
@@ -165,6 +181,8 @@ export const scenariosRouter = createTRPCRouter({
throw new Error(`Scenario with id ${input.id} does not exist`); throw new Error(`Scenario with id ${input.id} does not exist`);
} }
await requireCanModifyExperiment(existing.experimentId, ctx);
const newScenario = await prisma.testScenario.create({ const newScenario = await prisma.testScenario.create({
data: { data: {
experimentId: existing.experimentId, experimentId: existing.experimentId,

View File

@@ -1,11 +1,14 @@
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const templateVarsRouter = createTRPCRouter({ export const templateVarsRouter = createTRPCRouter({
create: publicProcedure create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() })) .input(z.object({ experimentId: z.string(), label: z.string() }))
.mutation(async ({ input }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.templateVariable.create({ await prisma.templateVariable.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
@@ -14,22 +17,33 @@ export const templateVarsRouter = createTRPCRouter({
}); });
}), }),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { delete: protectedProcedure
await prisma.templateVariable.delete({ where: { id: input.id } }); .input(z.object({ id: z.string() }))
}), .mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { await requireCanModifyExperiment(experimentId, ctx);
return await prisma.templateVariable.findMany({
where: { await prisma.templateVariable.delete({ where: { id: input.id } });
experimentId: input.experimentId, }),
},
orderBy: { list: publicProcedure
createdAt: "asc", .input(z.object({ experimentId: z.string() }))
}, .query(async ({ input, ctx }) => {
select: { await requireCanViewExperiment(input.experimentId, ctx);
id: true, return await prisma.templateVariable.findMany({
label: true, where: {
}, experimentId: input.experimentId,
}); },
}), orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
}); });

View File

@@ -27,6 +27,9 @@ type CreateContextOptions = {
session: Session | null; session: Session | null;
}; };
// eslint-disable-next-line @typescript-eslint/no-empty-function
const noOp = () => {};
/** /**
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export * This helper generates the "internals" for a tRPC context. If you need to use it, you can export
* it from here. * it from here.
@@ -41,6 +44,7 @@ const createInnerTRPCContext = (opts: CreateContextOptions) => {
return { return {
session: opts.session, session: opts.session,
prisma, prisma,
markAccessControlRun: noOp,
}; };
}; };
@@ -69,6 +73,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
* errors on the backend. * errors on the backend.
*/ */
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
const t = initTRPC.context<typeof createTRPCContext>().create({ const t = initTRPC.context<typeof createTRPCContext>().create({
transformer: superjson, transformer: superjson,
errorFormatter({ shape, error }) { errorFormatter({ shape, error }) {
@@ -106,16 +112,29 @@ export const createTRPCRouter = t.router;
export const publicProcedure = t.procedure; export const publicProcedure = t.procedure;
/** Reusable middleware that enforces users are logged in before running the procedure. */ /** Reusable middleware that enforces users are logged in before running the procedure. */
const enforceUserIsAuthed = t.middleware(({ ctx, next }) => { const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
if (!ctx.session || !ctx.session.user) { if (!ctx.session || !ctx.session.user) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
return next({
let accessControlRun = false;
const resp = await next({
ctx: { ctx: {
// infers the `session` as non-nullable // infers the `session` as non-nullable
session: { ...ctx.session, user: ctx.session.user }, session: { ...ctx.session, user: ctx.session.user },
markAccessControlRun: () => {
accessControlRun = true;
},
}, },
}); });
if (!accessControlRun)
throw new TRPCError({
code: "INTERNAL_SERVER_ERROR",
message:
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
});
return resp;
}); });
/** /**

View File

@@ -2,6 +2,8 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
import { type GetServerSidePropsContext } from "next"; import { type GetServerSidePropsContext } from "next";
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth"; import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import GitHubProvider from "next-auth/providers/github";
import { env } from "~/env.mjs";
/** /**
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session` * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
@@ -41,20 +43,15 @@ export const authOptions: NextAuthOptions = {
}, },
adapter: PrismaAdapter(prisma), adapter: PrismaAdapter(prisma),
providers: [ providers: [
// DiscordProvider({ GitHubProvider({
// clientId: env.DISCORD_CLIENT_ID, clientId: env.GITHUB_CLIENT_ID,
// clientSecret: env.DISCORD_CLIENT_SECRET, clientSecret: env.GITHUB_CLIENT_SECRET,
// }), }),
/**
* ...add more providers here.
*
* Most other providers require a bit more work than the Discord provider. For example, the
* GitHub provider requires you to add the `refresh_token_expires_in` field to the Account
* model. Refer to the NextAuth.js docs for the provider you want to use. Example:
*
* @see https://next-auth.js.org/providers/github
*/
], ],
theme: {
logo: "/logo.svg",
brandColor: "#ff5733",
},
}; };
/** /**

View File

@@ -8,7 +8,10 @@ const globalForPrisma = globalThis as unknown as {
export const prisma = export const prisma =
globalForPrisma.prisma ?? globalForPrisma.prisma ??
new PrismaClient({ new PrismaClient({
log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"], log:
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
? ["query", "error", "warn"]
: ["error"],
}); });
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma; if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;

77
src/server/modelStats.ts Normal file
View File

@@ -0,0 +1,77 @@
import { type SupportedModel } from "./types";
interface ModelStats {
contextLength: number;
promptTokenPrice: number;
completionTokenPrice: number;
speed: "fast" | "medium" | "slow";
provider: "OpenAI";
learnMoreUrl: string;
}
export const modelStats: Record<SupportedModel, ModelStats> = {
"gpt-4": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-0613": {
contextLength: 8192,
promptTokenPrice: 0.00003,
completionTokenPrice: 0.00006,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-4-32k-0613": {
contextLength: 32768,
promptTokenPrice: 0.00006,
completionTokenPrice: 0.00012,
speed: "medium",
provider: "OpenAI",
learnMoreUrl: "https://openai.com/gpt-4",
},
"gpt-3.5-turbo": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-0613": {
contextLength: 4096,
promptTokenPrice: 0.0000015,
completionTokenPrice: 0.000002,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
"gpt-3.5-turbo-16k-0613": {
contextLength: 16384,
promptTokenPrice: 0.000003,
completionTokenPrice: 0.000004,
speed: "fast",
provider: "OpenAI",
learnMoreUrl: "https://platform.openai.com/docs/guides/gpt/chat-completions-api",
},
};

View File

@@ -1,47 +0,0 @@
import { type Prisma } from "@prisma/client";
import { prisma } from "../db";
async function migrateScenarioVariantOutputData() {
// Get all ScenarioVariantCells
const cells = await prisma.scenarioVariantCell.findMany({ include: { modelOutput: true } });
console.log(`Found ${cells.length} records`);
let updatedCount = 0;
// Loop through all scenarioVariants
for (const cell of cells) {
// Create a new ModelOutput for each ScenarioVariant with an existing output
if (cell.output && !cell.modelOutput) {
updatedCount++;
await prisma.modelOutput.create({
data: {
scenarioVariantCellId: cell.id,
inputHash: cell.inputHash || "",
output: cell.output as Prisma.InputJsonValue,
timeToComplete: cell.timeToComplete ?? undefined,
promptTokens: cell.promptTokens,
completionTokens: cell.completionTokens,
createdAt: cell.createdAt,
updatedAt: cell.updatedAt,
},
});
} else if (cell.errorMessage && cell.retrievalStatus === "COMPLETE") {
updatedCount++;
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: {
retrievalStatus: "ERROR",
},
});
}
}
console.log("Data migration completed");
console.log(`Updated ${updatedCount} records`);
}
// Execute the function
migrateScenarioVariantOutputData().catch((error) => {
console.error("An error occurred while migrating data: ", error);
process.exit(1);
});

View File

@@ -0,0 +1,26 @@
// /* eslint-disable */
// import "dotenv/config";
// import Replicate from "replicate";
// const replicate = new Replicate({
// auth: process.env.REPLICATE_API_TOKEN || "",
// });
// console.log("going to run");
// const prediction = await replicate.predictions.create({
// version: "e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
// input: {
// prompt: "...",
// },
// });
// console.log("waiting");
// setInterval(() => {
// replicate.predictions.get(prediction.id).then((prediction) => {
// console.log(prediction.output);
// });
// }, 500);
// // const output = await replicate.wait(prediction, {});
// // console.log(output);

View File

@@ -1,7 +1,7 @@
import crypto from "crypto"; import crypto from "crypto";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import defineTask from "./defineTask"; import defineTask from "./defineTask";
import { type CompletionResponse, getCompletion } from "../utils/getCompletion"; import { type CompletionResponse, getOpenAIChatCompletion } from "../utils/getCompletion";
import { type JSONSerializable } from "../types"; import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep"; import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream"; import { shouldStream } from "../utils/shouldStream";
@@ -29,7 +29,10 @@ const getCompletionWithRetries = async (
let modelResponse: CompletionResponse | null = null; let modelResponse: CompletionResponse | null = null;
try { try {
for (let i = 0; i < MAX_AUTO_RETRIES; i++) { for (let i = 0; i < MAX_AUTO_RETRIES; i++) {
modelResponse = await getCompletion(payload as unknown as CompletionCreateParams, channel); modelResponse = await getOpenAIChatCompletion(
payload as unknown as CompletionCreateParams,
channel,
);
if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) { if (modelResponse.statusCode !== 429 || i === MAX_AUTO_RETRIES - 1) {
return modelResponse; return modelResponse;
} }
@@ -50,7 +53,7 @@ const getCompletionWithRetries = async (
return { return {
statusCode: modelResponse?.statusCode ?? 500, statusCode: modelResponse?.statusCode ?? 500,
errorMessage: modelResponse?.errorMessage ?? (error as Error).message, errorMessage: modelResponse?.errorMessage ?? (error as Error).message,
output: null as unknown as Prisma.InputJsonValue, output: null,
timeToComplete: 0, timeToComplete: 0,
}; };
} }
@@ -67,6 +70,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
include: { modelOutput: true }, include: { modelOutput: true },
}); });
if (!cell) { if (!cell) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Cell not found",
retrievalStatus: "ERROR",
},
});
return; return;
} }
@@ -85,6 +96,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
where: { id: cell.promptVariantId }, where: { id: cell.promptVariantId },
}); });
if (!variant) { if (!variant) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
});
return; return;
} }
@@ -92,6 +111,14 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
where: { id: cell.testScenarioId }, where: { id: cell.testScenarioId },
}); });
if (!scenario) { if (!scenario) {
await prisma.scenarioVariantCell.update({
where: { id: scenarioVariantCellId },
data: {
statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
});
return; return;
} }
@@ -125,10 +152,11 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
data: { data: {
scenarioVariantCellId, scenarioVariantCellId,
inputHash, inputHash,
output: modelResponse.output, output: modelResponse.output as unknown as Prisma.InputJsonObject,
timeToComplete: modelResponse.timeToComplete, timeToComplete: modelResponse.timeToComplete,
promptTokens: modelResponse.promptTokens, promptTokens: modelResponse.promptTokens,
completionTokens: modelResponse.completionTokens, completionTokens: modelResponse.completionTokens,
cost: modelResponse.cost,
}, },
}); });
} }

View File

@@ -0,0 +1,123 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
import { type CompletionCreateParams } from "openai/resources/chat/completions";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
return originalVariant.constructFn;
}
if (originalVariant && (newModel || instructions)) {
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
}
return dedent`
prompt = {
model: "gpt-3.5-turbo",
messages: [
{
role: "system",
content: "Return 'Hello, world!'",
}
]
}`;
}
const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant,
newModel?: SupportedModel,
instructions?: string,
) => {
const originalModel = originalVariant.model as SupportedModel;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming.Message[] = [
{
role: "system",
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
null,
2,
)}`,
},
];
if (newModel) {
messages.push({
role: "user",
content: `Return the prompt constructor function for ${newModel} given the following prompt constructor function for ${originalModel}:\n---\n${originalVariant.constructFn}`,
});
}
if (instructions) {
messages.push({
role: "user",
content: `Follow these instructions: ${instructions}`,
});
}
messages.push({
role: "user",
content:
"The prompt variable has already been declared, so do not declare it again. Rewrite the entire prompt constructor function.",
});
const completion = await openai.chat.completions.create({
model: "gpt-4",
messages,
functions: [
{
name: "update_prompt_constructor_function",
parameters: {
type: "object",
properties: {
new_prompt_function: {
type: "string",
description: "The new prompt function, runnable in typescript",
},
},
},
},
],
function_call: {
name: "update_prompt_constructor_function",
},
});
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
const code = `
global.contructPromptFunctionArgs = ${argString};
`;
const context = await isolate.createContext();
const jail = context.global;
await jail.set("global", jail.derefInto());
const script = await isolate.compileScript(code);
await script.run(context);
const contructPromptFunctionArgs = (await context.global.get(
"contructPromptFunctionArgs",
)) as ivm.Reference;
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
if (args && isObject(args) && "new_prompt_function" in args) {
newContructionFn = args.new_prompt_function as string;
break;
}
} catch (e) {
console.error(e);
}
}
return newContructionFn;
};

View File

@@ -62,6 +62,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
inputHash, inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue, output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete, timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens, promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens, completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt, createdAt: matchingModelOutput.createdAt,

View File

@@ -1,24 +1,25 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */ /* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash-es"; import { isObject } from "lodash-es";
import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai"; import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection"; import { wsConnection } from "~/utils/wsConnection";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { type OpenAIChatModel } from "../types"; import { type SupportedModel, type OpenAIChatModel } from "../types";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import { countOpenAIChatTokens } from "~/utils/countTokens"; import { countOpenAIChatTokens } from "~/utils/countTokens";
import { rateLimitErrorMessage } from "~/sharedStrings"; import { rateLimitErrorMessage } from "~/sharedStrings";
import { modelStats } from "../modelStats";
export type CompletionResponse = { export type CompletionResponse = {
output: Prisma.InputJsonValue | typeof Prisma.JsonNull; output: ChatCompletion | null;
statusCode: number; statusCode: number;
errorMessage: string | null; errorMessage: string | null;
timeToComplete: number; timeToComplete: number;
promptTokens?: number; promptTokens?: number;
completionTokens?: number; completionTokens?: number;
cost?: number;
}; };
export async function getCompletion( export async function getOpenAIChatCompletion(
payload: CompletionCreateParams, payload: CompletionCreateParams,
channel?: string, channel?: string,
): Promise<CompletionResponse> { ): Promise<CompletionResponse> {
@@ -35,7 +36,7 @@ export async function getCompletion(
}); });
const resp: CompletionResponse = { const resp: CompletionResponse = {
output: Prisma.JsonNull, output: null,
errorMessage: null, errorMessage: null,
statusCode: response.status, statusCode: response.status,
timeToComplete: 0, timeToComplete: 0,
@@ -52,7 +53,7 @@ export async function getCompletion(
} }
})().catch((err) => console.error(err)); })().catch((err) => console.error(err));
if (finalOutput) { if (finalOutput) {
resp.output = finalOutput as unknown as Prisma.InputJsonValue; resp.output = finalOutput;
resp.timeToComplete = Date.now() - start; resp.timeToComplete = Date.now() - start;
} }
} else { } else {
@@ -88,6 +89,13 @@ export async function getCompletion(
resp.completionTokens = countOpenAIChatTokens(model, messages); resp.completionTokens = countOpenAIChatTokens(model, messages);
} }
} }
const stats = modelStats[resp.output?.model as SupportedModel];
if (stats && resp.promptTokens && resp.completionTokens) {
resp.cost =
resp.promptTokens * stats.promptTokenPrice +
resp.completionTokens * stats.completionTokenPrice;
}
} catch (e) { } catch (e) {
console.error(e); console.error(e);
if (response.ok) { if (response.ok) {

View File

@@ -0,0 +1,7 @@
import { OpenAIChatModel, type SupportedModel } from "../types";
import openAIChatApiShape from "~/codegen/openai.types.ts.txt";
export const getApiShapeForModel = (model: SupportedModel) => {
if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};

View File

@@ -0,0 +1,65 @@
import { prisma } from "~/server/db";
export const reorderPromptVariants = async (
movedId: string,
stationaryTargetId: string,
alwaysInsertRight?: boolean,
) => {
const moved = await prisma.promptVariant.findUnique({
where: {
id: movedId,
},
});
const target = await prisma.promptVariant.findUnique({
where: {
id: stationaryTargetId,
},
});
if (!moved || !target || moved.experimentId !== target.experimentId) {
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
}
const visibleItems = await prisma.promptVariant.findMany({
where: {
experimentId: moved.experimentId,
visible: true,
},
orderBy: {
sortIndex: "asc",
},
});
// Remove the moved item from its current position
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
// Find the index of the moved item and the target item
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
// Determine the new index for the moved item
let newIndex;
if (movedIndex < targetIndex || alwaysInsertRight) {
newIndex = targetIndex + 1; // Insert after the target item
} else {
newIndex = targetIndex; // Insert before the target item
}
// Insert the moved item at the new position
orderedItems.splice(newIndex, 0, moved);
// Now, we need to update all the items with their new sortIndex
await prisma.$transaction(
orderedItems.map((item, index) => {
return prisma.promptVariant.update({
where: {
id: item.id,
},
data: {
sortIndex: index,
},
});
}),
);
};

View File

@@ -0,0 +1,19 @@
import { prisma } from "~/server/db";
export default async function userOrg(userId: string) {
return await prisma.organization.upsert({
where: {
personalOrgUserId: userId,
},
update: {},
create: {
personalOrgUserId: userId,
OrganizationUser: {
create: {
userId: userId,
role: "ADMIN",
},
},
},
});
}

View File

@@ -0,0 +1,49 @@
import { OrganizationUserRole } from "@prisma/client";
import { TRPCError } from "@trpc/server";
import { type TRPCContext } from "~/server/api/trpc";
import { prisma } from "~/server/db";
// No-op method for protected routes that really should be accessible to anyone.
export const requireNothing = (ctx: TRPCContext) => {
ctx.markAccessControlRun();
};
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
await prisma.experiment.findFirst({
where: { id: experimentId },
});
// Right now all experiments are publicly viewable, so this is a no-op.
ctx.markAccessControlRun();
};
export const canModifyExperiment = async (experimentId: string, userId: string) => {
const experiment = await prisma.experiment.findFirst({
where: {
id: experimentId,
organization: {
OrganizationUser: {
some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] },
userId,
},
},
},
},
});
return !!experiment;
};
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!(await canModifyExperiment(experimentId, userId))) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};

View File

@@ -1,46 +0,0 @@
import { type SupportedModel, OpenAIChatModel } from "~/server/types";
const openAIPromptTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00003,
"gpt-4-0613": 0.00003,
"gpt-4-32k": 0.00006,
"gpt-4-32k-0613": 0.00006,
"gpt-3.5-turbo": 0.0000015,
"gpt-3.5-turbo-0613": 0.0000015,
"gpt-3.5-turbo-16k": 0.000003,
"gpt-3.5-turbo-16k-0613": 0.000003,
};
const openAICompletionTokensToDollars: { [key in OpenAIChatModel]: number } = {
"gpt-4": 0.00006,
"gpt-4-0613": 0.00006,
"gpt-4-32k": 0.00012,
"gpt-4-32k-0613": 0.00012,
"gpt-3.5-turbo": 0.000002,
"gpt-3.5-turbo-0613": 0.000002,
"gpt-3.5-turbo-16k": 0.000004,
"gpt-3.5-turbo-16k-0613": 0.000004,
};
export const calculateTokenCost = (
model: SupportedModel | string | null,
numTokens: number,
isCompletion = false,
) => {
if (!model) return 0;
if (model in OpenAIChatModel) {
return calculateOpenAIChatTokenCost(model as OpenAIChatModel, numTokens, isCompletion);
}
return 0;
};
const calculateOpenAIChatTokenCost = (
model: OpenAIChatModel,
numTokens: number,
isCompletion: boolean,
) => {
const tokensToDollars = isCompletion
? openAICompletionTokensToDollars[model]
: openAIPromptTokensToDollars[model];
return tokensToDollars * numTokens;
};

View File

@@ -5,16 +5,5 @@ import relativeTime from "dayjs/plugin/relativeTime";
dayjs.extend(duration); dayjs.extend(duration);
dayjs.extend(relativeTime); dayjs.extend(relativeTime);
export const formatTimePast = (date: Date) => { export const formatTimePast = (date: Date) =>
const now = dayjs(); dayjs.duration(dayjs(date).diff(dayjs())).humanize(true);
const dayDiff = Math.floor(now.diff(date, "day"));
if (dayDiff > 0) return dayjs.duration(-dayDiff, "days").humanize(true);
const hourDiff = Math.floor(now.diff(date, "hour"));
if (hourDiff > 0) return dayjs.duration(-hourDiff, "hours").humanize(true);
const minuteDiff = Math.floor(now.diff(date, "minute"));
if (minuteDiff > 0) return dayjs.duration(-minuteDiff, "minutes").humanize(true);
return "a few seconds ago";
};

View File

@@ -12,6 +12,10 @@ export const useExperiment = () => {
return experiment; return experiment;
}; };
export const useExperimentAccess = () => {
return useExperiment().data?.access ?? { canView: false, canModify: false };
};
type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>; type AsyncFunction<T extends unknown[], U> = (...args: T) => Promise<U>;
export function useHandledAsyncCallback<T extends unknown[], U>( export function useHandledAsyncCallback<T extends unknown[], U>(

View File

@@ -1,4 +1,5 @@
import { extendTheme } from "@chakra-ui/react"; import { extendTheme } from "@chakra-ui/react";
import "@fontsource/inconsolata";
const systemFont = const systemFont =
'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"'; 'ui-sans-serif, -apple-system, "system-ui", "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol"';