Compare commits

..

21 Commits

Author SHA1 Message Date
Kyle Corbitt
d7cff0f52e Add caching in Python
Still need it in JS
2023-08-11 19:02:35 -07:00
Kyle Corbitt
8ed47eb4dd Add a python client library
We still don't have any documentation and things are in flux, but you can report your OpenAI API calls to OpenPipe now.
2023-08-11 16:54:50 -07:00
arcticfly
d9db6d80ea Update external types (#137)
* Separate server and frontend error logic

* Update types in external api
2023-08-11 15:02:14 -07:00
arcticfly
8d1ee62ff1 Record model and cost when reporting logs (#136)
* Rename prompt and completion tokens to input and output tokens

* Add getUsage function

* Record model and cost when reporting log

* Remove unused imports

* Move UsageGraph to its own component

* Standardize model response fields

* Fix types
2023-08-11 13:56:47 -07:00
arcticfly
f270579283 Auto-resize project menu width (#135) 2023-08-10 22:50:39 -07:00
arcticfly
81fbaeae44 Style project settings on mobile (#134)
* Style project settings on mobile

* Use auto-resize text area for display name

* Remove unused import
2023-08-10 22:15:45 -07:00
arcticfly
5277afa199 Change logo (#133)
* Change logo

* Add more vertical padding on desktop

* Fix prettier
2023-08-10 21:44:33 -07:00
arcticfly
76c34d64e6 Change menu styles (#132)
* Change ProjectMenu placement

* Reduce UserMenu width
2023-08-10 18:48:23 -07:00
Kyle Corbitt
454ac9a0d3 Merge pull request #131 from OpenPipe/better-template-vars
Better scenario variable editing
2023-08-10 12:25:54 -07:00
Kyle Corbitt
5ed7adadf9 Better scenario variable editing
Some users have gotten confused by the scenario variable editing interface. This change makes the interface easier to understand.
2023-08-10 12:08:17 -07:00
Kyle Corbitt
b8e0f392ab Merge pull request #130 from OpenPipe/output-wrapping
Preserve linebreaks in model output
2023-08-10 07:26:55 -07:00
Kyle Corbitt
b2af83341d Preserve linebreaks in model output 2023-08-09 21:58:41 -07:00
Kyle Corbitt
e6d229d5f9 Merge pull request #129 from OpenPipe/persist-proj
persist the currently-selected project
2023-08-09 17:05:17 -07:00
Kyle Corbitt
1a6ae3aef7 Merge pull request #128 from OpenPipe/proj-styling
Sidebar styling
2023-08-09 17:05:02 -07:00
Kyle Corbitt
9051d80775 Sidebar styling
Unify the menu styles between the UserMenu and ProjectMenu
2023-08-09 16:47:09 -07:00
Kyle Corbitt
6c060c6ea0 persist the currently-selected project 2023-08-09 16:45:54 -07:00
Kyle Corbitt
f70e73e338 Merge pull request #126 from OpenPipe/org-to-proj
Rename Organization to Project
2023-08-09 16:04:36 -07:00
Kyle Corbitt
16aa6672fc Rename Organization to Project
We'll probably need a concept of organizations at some point in the future, but in practice the way we're using these in the codebase right now is as a project, so this renames it to that to avoid confusion.
2023-08-09 16:01:13 -07:00
Kyle Corbitt
ac99c8e0f7 Merge pull request #127 from OpenPipe/pause-champs
Pause world championships
2023-08-09 15:59:15 -07:00
Kyle Corbitt
df121db78c Merge pull request #125 from OpenPipe/claude-1.1
Support Claude Instant 1.2
2023-08-09 15:58:36 -07:00
Kyle Corbitt
f09dfe18be Support Claude Instant 1.2 2023-08-09 14:43:54 -07:00
123 changed files with 4119 additions and 1617 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
.env
.venv/
*.pyc

View File

@@ -65,7 +65,14 @@ OpenPipe includes a tool to generate new test scenarios based on your existing p
4. Clone this repository: `git clone https://github.com/openpipe/openpipe` 4. Clone this repository: `git clone https://github.com/openpipe/openpipe`
5. Install the dependencies: `cd openpipe && pnpm install` 5. Install the dependencies: `cd openpipe && pnpm install`
6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`. 6. Create a `.env` file (`cp .env.example .env`) and enter your `OPENAI_API_KEY`.
7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma db push` to create the database. 7. Update `DATABASE_URL` if necessary to point to your Postgres instance and run `pnpm prisma migrate dev` to create the database.
8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!) 8. Create a [GitHub OAuth App](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and update the `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` values. (Note: a PR to make auth optional when running locally would be a great contribution!)
9. Start the app: `pnpm dev`. 9. Start the app: `pnpm dev`.
10. Navigate to [http://localhost:3000](http://localhost:3000) 10. Navigate to [http://localhost:3000](http://localhost:3000)
## Testing Locally
1. Copy your `.env` file to `.env.test`.
2. Update the `DATABASE_URL` to have a different database name than your development one
3. Run `DATABASE_URL=[your new datatase url] pnpm prisma migrate dev --skip-seed --skip-generate`
4. Run `pnpm test`

1
app/.gitignore vendored
View File

@@ -34,6 +34,7 @@ yarn-error.log*
# do not commit any .env files to git, except for the .env.example file. https://create.t3.gg/en/usage/env-variables#using-environment-variables # do not commit any .env files to git, except for the .env.example file. https://create.t3.gg/en/usage/env-variables#using-environment-variables
.env .env
.env*.local .env*.local
.env.test
# vercel # vercel
.vercel .vercel

View File

@@ -19,7 +19,7 @@
"codegen": "tsx src/server/scripts/client-codegen.ts", "codegen": "tsx src/server/scripts/client-codegen.ts",
"seed": "tsx prisma/seed.ts", "seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'", "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'",
"test": "pnpm vitest --no-threads" "test": "pnpm vitest"
}, },
"dependencies": { "dependencies": {
"@anthropic-ai/sdk": "^0.5.8", "@anthropic-ai/sdk": "^0.5.8",

View File

@@ -0,0 +1,37 @@
-- Rename Enum
ALTER TYPE "OrganizationUserRole" RENAME TO "ProjectUserRole";
-- Drop and recreate foreign keys
ALTER TABLE "ApiKey" DROP CONSTRAINT "ApiKey_organizationId_fkey";
ALTER TABLE "Dataset" DROP CONSTRAINT "Dataset_organizationId_fkey";
ALTER TABLE "Experiment" DROP CONSTRAINT "Experiment_organizationId_fkey";
ALTER TABLE "LoggedCall" DROP CONSTRAINT "LoggedCall_organizationId_fkey";
ALTER TABLE "OrganizationUser" DROP CONSTRAINT "OrganizationUser_organizationId_fkey";
ALTER TABLE "OrganizationUser" DROP CONSTRAINT "OrganizationUser_userId_fkey";
-- Rename columns
ALTER TABLE "ApiKey" RENAME COLUMN "organizationId" TO "projectId";
ALTER TABLE "Dataset" RENAME COLUMN "organizationId" TO "projectId";
ALTER TABLE "Experiment" RENAME COLUMN "organizationId" TO "projectId";
ALTER TABLE "LoggedCall" RENAME COLUMN "organizationId" TO "projectId";
ALTER TABLE "OrganizationUser" RENAME COLUMN "organizationId" TO "projectId";
ALTER TABLE "Organization" RENAME COLUMN "personalOrgUserId" TO "personalProjectUserId";
-- Rename table
ALTER TABLE "Organization" RENAME TO "Project";
ALTER TABLE "OrganizationUser" RENAME TO "ProjectUser";
-- Recreate foreign keys
ALTER TABLE "Experiment" ADD CONSTRAINT "Experiment_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "Dataset" ADD CONSTRAINT "Dataset_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ProjectUser" ADD CONSTRAINT "ProjectUser_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ProjectUser" ADD CONSTRAINT "ProjectUser_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "LoggedCall" ADD CONSTRAINT "LoggedCall_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
ALTER TABLE "ApiKey" ADD CONSTRAINT "ApiKey_projectId_fkey" FOREIGN KEY ("projectId") REFERENCES "Project"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- Rename indexes
ALTER TABLE "Project" RENAME CONSTRAINT "Organization_pkey" TO "Project_pkey";
ALTER TABLE "ProjectUser" RENAME CONSTRAINT "OrganizationUser_pkey" TO "ProjectUser_pkey";
ALTER TABLE "Project" RENAME CONSTRAINT "Organization_personalOrgUserId_fkey" TO "Project_personalProjectUserId_fkey";
ALTER INDEX "Organization_personalOrgUserId_key" RENAME TO "Project_personalProjectUserId_key";
ALTER INDEX "OrganizationUser_organizationId_userId_key" RENAME TO "ProjectUser_projectId_userId_key";

View File

@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

View File

@@ -0,0 +1,66 @@
/*
Warnings:
- You are about to rename the column `completionTokens` to `outputTokens` on the `ModelResponse` table.
- You are about to rename the column `promptTokens` to `inputTokens` on the `ModelResponse` table.
- You are about to rename the column `startTime` on the `LoggedCall` table to `requestedAt`. Ensure compatibility with application logic.
- You are about to rename the column `startTime` on the `LoggedCallModelResponse` table to `requestedAt`. Ensure compatibility with application logic.
- You are about to rename the column `endTime` on the `LoggedCallModelResponse` table to `receivedAt`. Ensure compatibility with application logic.
- You are about to rename the column `error` on the `LoggedCallModelResponse` table to `errorMessage`. Ensure compatibility with application logic.
- You are about to rename the column `respStatus` on the `LoggedCallModelResponse` table to `statusCode`. Ensure compatibility with application logic.
- You are about to rename the column `totalCost` on the `LoggedCallModelResponse` table to `cost`. Ensure compatibility with application logic.
- You are about to rename the column `inputHash` on the `ModelResponse` table to `cacheKey`. Ensure compatibility with application logic.
- You are about to rename the column `output` on the `ModelResponse` table to `respPayload`. Ensure compatibility with application logic.
*/
-- DropIndex
DROP INDEX "LoggedCall_startTime_idx";
-- DropIndex
DROP INDEX "ModelResponse_inputHash_idx";
-- Rename completionTokens to outputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "completionTokens" TO "outputTokens";
-- Rename promptTokens to inputTokens
ALTER TABLE "ModelResponse"
RENAME COLUMN "promptTokens" TO "inputTokens";
-- AlterTable
ALTER TABLE "LoggedCall"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "startTime" TO "requestedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "endTime" TO "receivedAt";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "error" TO "errorMessage";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "respStatus" TO "statusCode";
-- AlterTable
ALTER TABLE "LoggedCallModelResponse"
RENAME COLUMN "totalCost" TO "cost";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "inputHash" TO "cacheKey";
-- AlterTable
ALTER TABLE "ModelResponse"
RENAME COLUMN "output" TO "respPayload";
-- CreateIndex
CREATE INDEX "LoggedCall_requestedAt_idx" ON "LoggedCall"("requestedAt");
-- CreateIndex
CREATE INDEX "ModelResponse_cacheKey_idx" ON "ModelResponse"("cacheKey");

View File

@@ -16,8 +16,8 @@ model Experiment {
sortIndex Int @default(0) sortIndex Int @default(0)
organizationId String @db.Uuid projectId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@ -112,13 +112,13 @@ model ScenarioVariantCell {
model ModelResponse { model ModelResponse {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
inputHash String cacheKey String
requestedAt DateTime? requestedAt DateTime?
receivedAt DateTime? receivedAt DateTime?
output Json? respPayload Json?
cost Float? cost Float?
promptTokens Int? inputTokens Int?
completionTokens Int? outputTokens Int?
statusCode Int? statusCode Int?
errorMessage String? errorMessage String?
retryTime DateTime? retryTime DateTime?
@@ -131,7 +131,7 @@ model ModelResponse {
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluations OutputEvaluation[] outputEvaluations OutputEvaluation[]
@@index([inputHash]) @@index([cacheKey])
} }
enum EvalType { enum EvalType {
@@ -180,8 +180,8 @@ model Dataset {
name String name String
datasetEntries DatasetEntry[] datasetEntries DatasetEntry[]
organizationId String @db.Uuid projectId String @db.Uuid
organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@ -200,36 +200,35 @@ model DatasetEntry {
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
} }
// TODO rename Organization to Project model Project {
model Organization { id String @id @default(uuid()) @db.Uuid
id String @id @default(uuid()) @db.Uuid name String @default("Project 1")
name String @default("Project 1")
personalOrgUserId String? @unique @db.Uuid personalProjectUserId String? @unique @db.Uuid
personalOrgUser User? @relation(fields: [personalOrgUserId], references: [id], onDelete: Cascade) personalProjectUser User? @relation(fields: [personalProjectUserId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
organizationUsers OrganizationUser[] projectUsers ProjectUser[]
experiments Experiment[] experiments Experiment[]
datasets Dataset[] datasets Dataset[]
loggedCalls LoggedCall[] loggedCalls LoggedCall[]
apiKeys ApiKey[] apiKeys ApiKey[]
} }
enum OrganizationUserRole { enum ProjectUserRole {
ADMIN ADMIN
MEMBER MEMBER
VIEWER VIEWER
} }
model OrganizationUser { model ProjectUser {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
role OrganizationUserRole role ProjectUserRole
organizationId String @db.Uuid projectId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade)
userId String @db.Uuid userId String @db.Uuid
user User @relation(fields: [userId], references: [id], onDelete: Cascade) user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -237,7 +236,7 @@ model OrganizationUser {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@unique([organizationId, userId]) @@unique([projectId, userId])
} }
model WorldChampEntrant { model WorldChampEntrant {
@@ -257,7 +256,7 @@ model WorldChampEntrant {
model LoggedCall { model LoggedCall {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
startTime DateTime requestedAt DateTime
// True if this call was served from the cache, false otherwise // True if this call was served from the cache, false otherwise
cacheHit Boolean cacheHit Boolean
@@ -265,21 +264,21 @@ model LoggedCall {
// A LoggedCall is always associated with a LoggedCallModelResponse. If this // A LoggedCall is always associated with a LoggedCallModelResponse. If this
// is a cache miss, we create a new LoggedCallModelResponse. // is a cache miss, we create a new LoggedCallModelResponse.
// If it's a cache hit, it's a pre-existing LoggedCallModelResponse. // If it's a cache hit, it's a pre-existing LoggedCallModelResponse.
modelResponseId String? @db.Uuid modelResponseId String? @db.Uuid
modelResponse LoggedCallModelResponse? @relation(fields: [modelResponseId], references: [id], onDelete: Cascade) modelResponse LoggedCallModelResponse? @relation(fields: [modelResponseId], references: [id], onDelete: Cascade)
// The responses created by this LoggedCall. Will be empty if this LoggedCall was a cache hit. // The responses created by this LoggedCall. Will be empty if this LoggedCall was a cache hit.
createdResponses LoggedCallModelResponse[] @relation(name: "ModelResponseOriginalCall") createdResponses LoggedCallModelResponse[] @relation(name: "ModelResponseOriginalCall")
organizationId String @db.Uuid projectId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade)
tags LoggedCallTag[] tags LoggedCallTag[]
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@index([startTime]) @@index([requestedAt])
} }
model LoggedCallModelResponse { model LoggedCallModelResponse {
@@ -288,14 +287,14 @@ model LoggedCallModelResponse {
reqPayload Json reqPayload Json
// The HTTP status returned by the model provider // The HTTP status returned by the model provider
respStatus Int? statusCode Int?
respPayload Json? respPayload Json?
// Should be null if the request was successful, and some string if the request failed. // Should be null if the request was successful, and some string if the request failed.
error String? errorMessage String?
startTime DateTime requestedAt DateTime
endTime DateTime receivedAt DateTime
// Note: the function to calculate the cacheKey should include the project // Note: the function to calculate the cacheKey should include the project
// ID so we don't share cached responses between projects, which could be an // ID so we don't share cached responses between projects, which could be an
@@ -309,7 +308,7 @@ model LoggedCallModelResponse {
outputTokens Int? outputTokens Int?
finishReason String? finishReason String?
completionId String? completionId String?
totalCost Decimal? @db.Decimal(18, 12) cost Decimal? @db.Decimal(18, 12)
// The LoggedCall that created this LoggedCallModelResponse // The LoggedCall that created this LoggedCallModelResponse
originalLoggedCallId String @unique @db.Uuid originalLoggedCallId String @unique @db.Uuid
@@ -323,11 +322,11 @@ model LoggedCallModelResponse {
} }
model LoggedCallTag { model LoggedCallTag {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
name String name String
value String? value String?
loggedCallId String @db.Uuid loggedCallId String @db.Uuid
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
@@index([name]) @@index([name])
@@ -340,8 +339,8 @@ model ApiKey {
name String name String
apiKey String @unique apiKey String @unique
organizationId String @db.Uuid projectId String @db.Uuid
organization Organization? @relation(fields: [organizationId], references: [id], onDelete: Cascade) project Project? @relation(fields: [projectId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@ -390,8 +389,8 @@ model User {
accounts Account[] accounts Account[]
sessions Session[] sessions Session[]
organizationUsers OrganizationUser[] projectUsers ProjectUser[]
organizations Organization[] projects Project[]
worldChampEntrant WorldChampEntrant? worldChampEntrant WorldChampEntrant?
createdAt DateTime @default(now()) createdAt DateTime @default(now())

View File

@@ -5,14 +5,14 @@ import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111111"; const defaultId = "11111111-1111-1111-1111-111111111111";
await prisma.organization.deleteMany({ await prisma.project.deleteMany({
where: { id: defaultId }, where: { id: defaultId },
}); });
// If there's an existing org, just seed into it // If there's an existing project, just seed into it
const org = const project =
(await prisma.organization.findFirst({})) ?? (await prisma.project.findFirst({})) ??
(await prisma.organization.create({ (await prisma.project.create({
data: { id: defaultId }, data: { id: defaultId },
})); }));
@@ -26,7 +26,7 @@ await prisma.experiment.create({
data: { data: {
id: defaultId, id: defaultId,
label: "Country Capitals Example", label: "Country Capitals Example",
organizationId: org.id, projectId: project.id,
}, },
}); });

View File

@@ -7,14 +7,14 @@ import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111112"; const defaultId = "11111111-1111-1111-1111-111111111112";
await prisma.organization.deleteMany({ await prisma.project.deleteMany({
where: { id: defaultId }, where: { id: defaultId },
}); });
// If there's an existing org, just seed into it // If there's an existing project, just seed into it
const org = const project =
(await prisma.organization.findFirst({})) ?? (await prisma.project.findFirst({})) ??
(await prisma.organization.create({ (await prisma.project.create({
data: { id: defaultId }, data: { id: defaultId },
})); }));
@@ -47,7 +47,7 @@ for (const dataset of datasets) {
const oldExperiment = await prisma.experiment.findFirst({ const oldExperiment = await prisma.experiment.findFirst({
where: { where: {
label: experimentName, label: experimentName,
organizationId: org.id, projectId: project.id,
}, },
}); });
if (oldExperiment) { if (oldExperiment) {
@@ -60,7 +60,7 @@ for (const dataset of datasets) {
data: { data: {
id: oldExperiment?.id ?? undefined, id: oldExperiment?.id ?? undefined,
label: experimentName, label: experimentName,
organizationId: org.id, projectId: project.id,
}, },
}); });

View File

@@ -311,9 +311,9 @@ const MODEL_RESPONSE_TEMPLATES: {
await prisma.loggedCallModelResponse.deleteMany(); await prisma.loggedCallModelResponse.deleteMany();
const org = await prisma.organization.findFirst({ const project = await prisma.project.findFirst({
where: { where: {
personalOrgUserId: { personalProjectUserId: {
not: null, not: null,
}, },
}, },
@@ -322,8 +322,8 @@ const org = await prisma.organization.findFirst({
}, },
}); });
if (!org) { if (!project) {
console.error("No org found. Sign up to create your first org."); console.error("No project found. Sign up to create your first project.");
process.exit(1); process.exit(1);
} }
@@ -339,17 +339,17 @@ for (let i = 0; i < 1437; i++) {
MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!; MODEL_RESPONSE_TEMPLATES[Math.floor(Math.random() * MODEL_RESPONSE_TEMPLATES.length)]!;
const model = template.reqPayload.model; const model = template.reqPayload.model;
// choose random time in the last two weeks, with a bias towards the last few days // choose random time in the last two weeks, with a bias towards the last few days
const startTime = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14); const requestedAt = new Date(Date.now() - Math.pow(Math.random(), 2) * 1000 * 60 * 60 * 24 * 14);
// choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5 // choose random delay anywhere from 2 to 10 seconds later for gpt-4, or 1 to 5 seconds for gpt-3.5
const delay = const delay =
model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4; model === "gpt-4" ? 1000 * 2 + Math.random() * 1000 * 8 : 1000 + Math.random() * 1000 * 4;
const endTime = new Date(startTime.getTime() + delay); const receivedAt = new Date(requestedAt.getTime() + delay);
loggedCallsToCreate.push({ loggedCallsToCreate.push({
id: loggedCallId, id: loggedCallId,
cacheHit: false, cacheHit: false,
startTime, requestedAt,
organizationId: org.id, projectId: project.id,
createdAt: startTime, createdAt: requestedAt,
}); });
const { promptTokenPrice, completionTokenPrice } = const { promptTokenPrice, completionTokenPrice } =
@@ -365,21 +365,20 @@ for (let i = 0; i < 1437; i++) {
loggedCallModelResponsesToCreate.push({ loggedCallModelResponsesToCreate.push({
id: loggedCallModelResponseId, id: loggedCallModelResponseId,
startTime, requestedAt,
endTime, receivedAt,
originalLoggedCallId: loggedCallId, originalLoggedCallId: loggedCallId,
reqPayload: template.reqPayload, reqPayload: template.reqPayload,
respPayload: template.respPayload, respPayload: template.respPayload,
respStatus: template.respStatus, statusCode: template.respStatus,
error: template.error, errorMessage: template.error,
createdAt: startTime, createdAt: requestedAt,
cacheKey: hashRequest(org.id, template.reqPayload as JsonValue), cacheKey: hashRequest(project.id, template.reqPayload as JsonValue),
durationMs: endTime.getTime() - startTime.getTime(), durationMs: receivedAt.getTime() - requestedAt.getTime(),
inputTokens: template.inputTokens, inputTokens: template.inputTokens,
outputTokens: template.outputTokens, outputTokens: template.outputTokens,
finishReason: template.finishReason, finishReason: template.finishReason,
totalCost: cost: template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
template.inputTokens * promptTokenPrice + template.outputTokens * completionTokenPrice,
}); });
loggedCallsToUpdate.push({ loggedCallsToUpdate.push({
where: { where: {

View File

@@ -6,14 +6,14 @@ import { promptConstructorVersion } from "~/promptConstructor/version";
const defaultId = "11111111-1111-1111-1111-111111111112"; const defaultId = "11111111-1111-1111-1111-111111111112";
await prisma.organization.deleteMany({ await prisma.project.deleteMany({
where: { id: defaultId }, where: { id: defaultId },
}); });
// If there's an existing org, just seed into it // If there's an existing project, just seed into it
const org = const project =
(await prisma.organization.findFirst({})) ?? (await prisma.project.findFirst({})) ??
(await prisma.organization.create({ (await prisma.project.create({
data: { id: defaultId }, data: { id: defaultId },
})); }));
@@ -27,7 +27,7 @@ const experimentName = `Twitter Sentiment Analysis`;
const oldExperiment = await prisma.experiment.findFirst({ const oldExperiment = await prisma.experiment.findFirst({
where: { where: {
label: experimentName, label: experimentName,
organizationId: org.id, projectId: project.id,
}, },
}); });
if (oldExperiment) { if (oldExperiment) {
@@ -40,7 +40,7 @@ const experiment = await prisma.experiment.create({
data: { data: {
id: oldExperiment?.id ?? undefined, id: oldExperiment?.id ?? undefined,
label: experimentName, label: experimentName,
organizationId: org.id, projectId: project.id,
}, },
}); });

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.8 KiB

After

Width:  |  Height:  |  Size: 6.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.1 KiB

After

Width:  |  Height:  |  Size: 5.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 704 B

After

Width:  |  Height:  |  Size: 800 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 KiB

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

After

Width:  |  Height:  |  Size: 3.4 KiB

View File

@@ -9,10 +9,9 @@ Created by potrace 1.14, written by Peter Selinger 2001-2017
</metadata> </metadata>
<g transform="translate(0.000000,550.000000) scale(0.100000,-0.100000)" <g transform="translate(0.000000,550.000000) scale(0.100000,-0.100000)"
fill="#000000" stroke="none"> fill="#000000" stroke="none">
<path d="M813 5478 c-18 -13 -37 -36 -43 -52 -6 -19 -10 -236 -10 -603 0 -638 <path d="M785 5474 l-25 -27 0 -622 0 -622 25 -27 24 -26 171 0 170 0 0 -2050
-1 -626 65 -657 25 -12 67 -16 179 -16 l146 0 0 -2032 0 -2032 23 -33 c12 -18 0 -2051 25 -25 24 -24 1557 2 1556 3 19 24 c19 23 19 70 19 2072 l0 2049 169
35 -37 51 -43 19 -7 539 -10 1528 -10 1663 0 1549 -5 1582 65 14 30 16 235 16 0 c165 0 169 1 195 25 l26 24 0 626 0 626 -26 24 -27 25 -1939 0 -1939 0 -24
2059 l0 2026 156 0 156 0 39 39 39 39 0 587 c0 651 1 638 -65 669 -30 14 -223 -26z"/>
16 -1932 16 l-1898 0 -32 -22z"/>
</g> </g>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 858 B

After

Width:  |  Height:  |  Size: 755 B

View File

@@ -1,5 +1,28 @@
<svg width="380" height="320" viewBox="0 0 380 320" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="398" height="550" viewBox="0 0 398 550" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M72 320L122.5 231L130.5 150.5L115 73L72 0H312L265 64.5L257 158.5L265 249L312 320H72Z" fill="#FF5733"/> <path d="M39 125H359V542C359 546.418 355.418 550 351 550H47C42.5817 550 39 546.418 39 542V125Z" fill="black"/>
<path d="M67.027 9.5C72.9909 9.5 79.5196 12.3449 86.3672 19.2588C93.2495 26.2075 99.8845 36.7468 105.66 50.5336C117.194 78.0671 124.554 116.764 124.554 160C124.554 203.236 117.194 241.933 105.66 269.466C99.8845 283.253 93.2495 293.793 86.3672 300.741C79.5196 307.655 72.9909 310.5 67.027 310.5C61.0632 310.5 54.5345 307.655 47.6868 300.741C40.8045 293.793 34.1695 283.253 28.394 269.466C16.8596 241.933 9.5 203.236 9.5 160C9.5 116.764 16.8596 78.0671 28.394 50.5336C34.1695 36.7468 40.8045 26.2075 47.6868 19.2588C54.5345 12.3449 61.0632 9.5 67.027 9.5Z" stroke="#FF5733" stroke-width="19"/> <path d="M0 8C0 3.58172 3.58172 0 8 0H390C394.418 0 398 3.58172 398 8V127C398 131.418 394.418 135 390 135H7.99999C3.58171 135 0 131.418 0 127V8Z" fill="black"/>
<path d="M312.027 9.5C317.991 9.5 324.52 12.3449 331.367 19.2588C338.25 26.2075 344.885 36.7468 350.66 50.5336C362.194 78.0671 369.554 116.764 369.554 160C369.554 203.236 362.194 241.933 350.66 269.466C344.885 283.253 338.25 293.793 331.367 300.741C324.52 307.655 317.991 310.5 312.027 310.5C306.063 310.5 299.534 307.655 292.687 300.741C285.805 293.793 279.17 283.253 273.394 269.466C261.86 241.933 254.5 203.236 254.5 160C254.5 116.764 261.86 78.0671 273.394 50.5336C279.17 36.7468 285.805 26.2075 292.687 19.2588C299.534 12.3449 306.063 9.5 312.027 9.5Z" stroke="#FF5733" stroke-width="19"/> <path d="M50 135H348V535C348 537.209 346.209 539 344 539H54C51.7909 539 50 537.209 50 535V135Z" fill="#FF5733"/>
<path d="M11 14.0001C11 11.791 12.7909 10.0001 15 10.0001H384C386.209 10.0001 388 11.791 388 14.0001V120C388 122.209 386.209 124 384 124H15C12.7909 124 11 122.209 11 120V14.0001Z" fill="#FF5733"/>
<path d="M11 14.0001C11 11.791 12.7909 10.0001 15 10.0001H384C386.209 10.0001 388 11.791 388 14.0001V120C388 122.209 386.209 124 384 124H15C12.7909 124 11 122.209 11 120V14.0001Z" fill="url(#paint0_linear_102_49)"/>
<path d="M50 134H348V535C348 537.209 346.209 539 344 539H54C51.7909 539 50 537.209 50 535V134Z" fill="url(#paint1_linear_102_49)"/>
<path d="M108 142H156V535H108V142Z" fill="white"/>
<path d="M300 135H348V535C348 537.209 346.209 539 344 539H300V135Z" fill="white" fill-opacity="0.25"/>
<path d="M96 142H108V535H96V142Z" fill="white" fill-opacity="0.5"/>
<path d="M84 10.0001H133V120H84V10.0001Z" fill="white"/>
<path d="M339 10.0001H384C386.209 10.0001 388 11.791 388 14.0001V120C388 122.209 386.209 124 384 124H339V10.0001Z" fill="white" fill-opacity="0.25"/>
<path d="M71.9995 10.0001H83.9995V120H71.9995V10.0001Z" fill="white" fill-opacity="0.5"/>
<path d="M108 534.529H156V539.019H108V534.529Z" fill="#AAAAAA"/>
<path opacity="0.5" d="M95.9927 534.529H107.982V539.019H95.9927V534.529Z" fill="#AAAAAA"/>
<path d="M84.0029 119.887H133.007V124.027H84.0029V119.887Z" fill="#AAAAAA"/>
<path opacity="0.5" d="M71.9883 119.887H83.978V124.027H71.9883V119.887Z" fill="#AAAAAA"/>
<defs>
<linearGradient id="paint0_linear_102_49" x1="335" y1="67.0001" x2="137" y2="67.0001" gradientUnits="userSpaceOnUse">
<stop stop-color="#D62600"/>
<stop offset="1" stop-color="#FF5733" stop-opacity="0"/>
</linearGradient>
<linearGradient id="paint1_linear_102_49" x1="306.106" y1="336.5" x2="149.597" y2="336.5" gradientUnits="userSpaceOnUse">
<stop stop-color="#D62600"/>
<stop offset="1" stop-color="#FF5733" stop-opacity="0"/>
</linearGradient>
</defs>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 1.4 KiB

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

After

Width:  |  Height:  |  Size: 26 KiB

View File

@@ -19,7 +19,7 @@ const CopiableCode = ({ code }: { code: string }) => {
w="full" w="full"
justifyContent="space-between" justifyContent="space-between"
> >
<Text fontFamily="inconsolata" fontWeight="bold" letterSpacing={0.5}> <Text fontFamily="inconsolata" fontWeight="bold" letterSpacing={0.5} overflowX="auto">
{code} {code}
</Text> </Text>
<Tooltip closeOnClick={false} label={copied ? "Copied!" : "Copy to clipboard"}> <Tooltip closeOnClick={false} label={copied ? "Copied!" : "Copy to clipboard"}>

View File

@@ -8,8 +8,8 @@ export default function Favicon() {
<link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" /> <link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
<link rel="manifest" href="/favicons/site.webmanifest" /> <link rel="manifest" href="/favicons/site.webmanifest" />
<link rel="shortcut icon" href="/favicons/favicon.ico" /> <link rel="shortcut icon" href="/favicons/favicon.ico" />
<link rel="mask-icon" href="/favicons/safari-pinned-tab.svg" color="#5bbad5" />
<meta name="msapplication-TileColor" content="#da532c" /> <meta name="msapplication-TileColor" content="#da532c" />
<meta name="msapplication-config" content="/favicons/browserconfig.xml" />
<meta name="theme-color" content="#ffffff" /> <meta name="theme-color" content="#ffffff" />
</Head> </Head>
); );

View File

@@ -12,6 +12,7 @@ import {
Select, Select,
FormHelperText, FormHelperText,
Code, Code,
IconButton,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type Evaluation, EvalType } from "@prisma/client"; import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react"; import { useCallback, useState } from "react";
@@ -183,46 +184,37 @@ export default function EditEvaluations() {
<Text flex={1}> <Text flex={1}>
{evaluation.evalType}: &quot;{evaluation.value}&quot; {evaluation.evalType}: &quot;{evaluation.value}&quot;
</Text> </Text>
<Button
<IconButton
aria-label="Edit"
variant="unstyled" variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset" minW="unset"
color="gray.400"
onClick={() => setEditingId(evaluation.id)} onClick={() => setEditingId(evaluation.id)}
_hover={{ _hover={{ color: "gray.800", cursor: "pointer" }}
color: "gray.800", icon={<Icon as={BsPencil} />}
cursor: "pointer", />
}} <IconButton
> aria-label="Delete"
<Icon as={BsPencil} boxSize={4} />
</Button>
<Button
variant="unstyled" variant="unstyled"
color="gray.400"
height="unset"
width="unset"
minW="unset" minW="unset"
color="gray.400"
onClick={() => onDelete(evaluation.id)} onClick={() => onDelete(evaluation.id)}
_hover={{ _hover={{ color: "gray.800", cursor: "pointer" }}
color: "gray.800", icon={<Icon as={BsX} boxSize={6} />}
cursor: "pointer", />
}}
>
<Icon as={BsX} boxSize={6} />
</Button>
</HStack> </HStack>
), ),
)} )}
{editingId == null && ( {editingId == null && (
<Button <Button
onClick={() => setEditingId("new")} onClick={() => setEditingId("new")}
alignSelf="flex-start" alignSelf="end"
size="sm" size="sm"
mt={4} mt={4}
colorScheme="blue" colorScheme="blue"
> >
Add Evaluation New Evaluation
</Button> </Button>
)} )}
{editingId == "new" && ( {editingId == "new" && (

View File

@@ -1,103 +1,185 @@
import { Text, Button, HStack, Heading, Icon, Input, Stack } from "@chakra-ui/react"; import { Text, Button, HStack, Heading, Icon, IconButton, Stack, VStack } from "@chakra-ui/react";
import { useState } from "react"; import { type TemplateVariable } from "@prisma/client";
import { BsCheck, BsX } from "react-icons/bs"; import { useEffect, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback, useScenarioVars } from "~/utils/hooks";
import { maybeReportError } from "~/utils/errorHandling/maybeReportError";
import { FloatingLabelInput } from "./FloatingLabelInput";
export const ScenarioVar = ({
variable,
isEditing,
setIsEditing,
}: {
variable: Pick<TemplateVariable, "id" | "label">;
isEditing: boolean;
setIsEditing: (isEditing: boolean) => void;
}) => {
const utils = api.useContext();
const [label, setLabel] = useState(variable.label);
useEffect(() => {
setLabel(variable.label);
}, [variable.label]);
const renameVarMutation = api.scenarioVars.rename.useMutation();
const [onRename] = useHandledAsyncCallback(async () => {
const resp = await renameVarMutation.mutateAsync({ id: variable.id, label });
if (maybeReportError(resp)) return;
setIsEditing(false);
await utils.scenarioVars.list.invalidate();
await utils.scenarios.list.invalidate();
}, [label, variable.id]);
const deleteMutation = api.scenarioVars.delete.useMutation();
const [onDeleteVar] = useHandledAsyncCallback(async () => {
await deleteMutation.mutateAsync({ id: variable.id });
await utils.scenarioVars.list.invalidate();
}, [variable.id]);
if (isEditing) {
return (
<HStack w="full">
<FloatingLabelInput
flex={1}
label="Renamed Variable"
value={label}
onChange={(e) => setLabel(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
onRename();
}
// If the user types a space, replace it with an underscore
if (e.key === " ") {
e.preventDefault();
setLabel((label) => label && `${label}_`);
}
}}
/>
<Button size="sm" onClick={() => setIsEditing(false)}>
Cancel
</Button>
<Button size="sm" colorScheme="blue" onClick={onRename}>
Save
</Button>
</HStack>
);
} else {
return (
<HStack w="full" borderTopWidth={1} borderColor="gray.200">
<Text flex={1}>{variable.label}</Text>
<IconButton
aria-label="Edit"
variant="unstyled"
minW="unset"
color="gray.400"
onClick={() => setIsEditing(true)}
_hover={{ color: "gray.800", cursor: "pointer" }}
icon={<Icon as={BsPencil} />}
/>
<IconButton
aria-label="Delete"
variant="unstyled"
minW="unset"
color="gray.400"
onClick={onDeleteVar}
_hover={{ color: "gray.800", cursor: "pointer" }}
icon={<Icon as={BsX} boxSize={6} />}
/>
</HStack>
);
}
};
export default function EditScenarioVars() { export default function EditScenarioVars() {
const experiment = useExperiment(); const experiment = useExperiment();
const vars = const vars = useScenarioVars();
api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const [currentlyEditingId, setCurrentlyEditingId] = useState<string | null>(null);
const [newVariable, setNewVariable] = useState<string>(""); const [newVariable, setNewVariable] = useState<string>("");
const newVarIsValid = newVariable.length > 0 && !vars.map((v) => v.label).includes(newVariable); const newVarIsValid = newVariable?.length ?? 0 > 0;
const utils = api.useContext(); const utils = api.useContext();
const addVarMutation = api.templateVars.create.useMutation(); const addVarMutation = api.scenarioVars.create.useMutation();
const [onAddVar] = useHandledAsyncCallback(async () => { const [onAddVar] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id) return; if (!experiment.data?.id) return;
if (!newVarIsValid) return; if (!newVariable) return;
await addVarMutation.mutateAsync({ const resp = await addVarMutation.mutateAsync({
experimentId: experiment.data.id, experimentId: experiment.data.id,
label: newVariable, label: newVariable,
}); });
await utils.templateVars.list.invalidate(); if (maybeReportError(resp)) return;
await utils.scenarioVars.list.invalidate();
setNewVariable(""); setNewVariable("");
}, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]); }, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]);
const deleteMutation = api.templateVars.delete.useMutation();
const [onDeleteVar] = useHandledAsyncCallback(async (id: string) => {
await deleteMutation.mutateAsync({ id });
await utils.templateVars.list.invalidate();
}, []);
return ( return (
<Stack> <Stack>
<Heading size="sm">Scenario Variables</Heading> <Heading size="sm">Scenario Variables</Heading>
<Stack spacing={2}> <VStack spacing={4}>
<Text fontSize="sm"> <Text fontSize="sm">
Scenario variables can be used in your prompt variants as well as evaluations. Scenario variables can be used in your prompt variants as well as evaluations.
</Text> </Text>
<HStack spacing={0}> <VStack spacing={0} w="full">
<Input {vars.data?.map((variable) => (
placeholder="Add Scenario Variable" <ScenarioVar
size="sm" variable={variable}
borderTopRadius={0}
borderRightRadius={0}
value={newVariable}
onChange={(e) => setNewVariable(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
onAddVar();
}
// If the user types a space, replace it with an underscore
if (e.key === " ") {
e.preventDefault();
setNewVariable((v) => v + "_");
}
}}
/>
<Button
size="xs"
height="100%"
borderLeftRadius={0}
isDisabled={!newVarIsValid}
onClick={onAddVar}
>
<Icon as={BsCheck} boxSize={8} />
</Button>
</HStack>
<HStack spacing={2} py={4} wrap="wrap">
{vars.map((variable) => (
<HStack
key={variable.id} key={variable.id}
spacing={0} isEditing={currentlyEditingId === variable.id}
bgColor="blue.100" setIsEditing={(isEditing) => {
color="blue.600" if (isEditing) {
pl={2} setCurrentlyEditingId(variable.id);
pr={0} } else {
fontWeight="bold" setCurrentlyEditingId(null);
> }
<Text fontSize="sm" flex={1}> }}
{variable.label} />
</Text>
<Button
size="xs"
variant="ghost"
colorScheme="blue"
p="unset"
minW="unset"
px="unset"
onClick={() => onDeleteVar(variable.id)}
>
<Icon as={BsX} boxSize={6} color="blue.800" />
</Button>
</HStack>
))} ))}
</HStack> </VStack>
</Stack> {currentlyEditingId !== "new" && (
<Button
colorScheme="blue"
size="sm"
onClick={() => setCurrentlyEditingId("new")}
alignSelf="end"
>
New Variable
</Button>
)}
{currentlyEditingId === "new" && (
<HStack w="full">
<FloatingLabelInput
flex={1}
label="New Variable"
value={newVariable}
onChange={(e) => setNewVariable(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
onAddVar();
}
// If the user types a space, replace it with an underscore
if (e.key === " ") {
e.preventDefault();
setNewVariable((v) => v && `${v}_`);
}
}}
/>
<Button size="sm" onClick={() => setCurrentlyEditingId(null)}>
Cancel
</Button>
<Button size="sm" colorScheme="blue" onClick={onAddVar}>
Save
</Button>
</HStack>
)}
</VStack>
</Stack> </Stack>
); );
} }

View File

@@ -1,7 +1,7 @@
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { type PromptVariant, type Scenario } from "../types"; import { type PromptVariant, type Scenario } from "../types";
import { type StackProps, Text, VStack } from "@chakra-ui/react"; import { type StackProps, Text, VStack } from "@chakra-ui/react";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useScenarioVars, useHandledAsyncCallback } from "~/utils/hooks";
import SyntaxHighlighter from "react-syntax-highlighter"; import SyntaxHighlighter from "react-syntax-highlighter";
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs"; import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
import stringify from "json-stringify-pretty-compact"; import stringify from "json-stringify-pretty-compact";
@@ -23,10 +23,7 @@ export default function OutputCell({
variant: PromptVariant; variant: PromptVariant;
}): ReactElement | null { }): ReactElement | null {
const utils = api.useContext(); const utils = api.useContext();
const experiment = useExperiment(); const vars = useScenarioVars().data;
const vars = api.templateVars.list.useQuery({
experimentId: experiment.data?.id ?? "",
}).data;
const scenarioVariables = scenario.variableValues as Record<string, string>; const scenarioVariables = scenario.variableValues as Record<string, string>;
const templateHasVariables = const templateHasVariables =
@@ -110,7 +107,7 @@ export default function OutputCell({
if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>; if (disabledReason) return <Text color="gray.500">{disabledReason}</Text>;
const showLogs = !streamedMessage && !mostRecentResponse?.output; const showLogs = !streamedMessage && !mostRecentResponse?.respPayload;
if (showLogs) if (showLogs)
return ( return (
@@ -163,13 +160,13 @@ export default function OutputCell({
</CellWrapper> </CellWrapper>
); );
const normalizedOutput = mostRecentResponse?.output const normalizedOutput = mostRecentResponse?.respPayload
? provider.normalizeOutput(mostRecentResponse?.output) ? provider.normalizeOutput(mostRecentResponse?.respPayload)
: streamedMessage : streamedMessage
? provider.normalizeOutput(streamedMessage) ? provider.normalizeOutput(streamedMessage)
: null; : null;
if (mostRecentResponse?.output && normalizedOutput?.type === "json") { if (mostRecentResponse?.respPayload && normalizedOutput?.type === "json") {
return ( return (
<CellWrapper> <CellWrapper>
<SyntaxHighlighter <SyntaxHighlighter
@@ -191,7 +188,7 @@ export default function OutputCell({
return ( return (
<CellWrapper> <CellWrapper>
<Text>{contentToDisplay}</Text> <Text whiteSpace="pre-wrap">{contentToDisplay}</Text>
</CellWrapper> </CellWrapper>
); );
} }

View File

@@ -19,8 +19,8 @@ export const OutputStats = ({
? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime() ? modelResponse.receivedAt.getTime() - modelResponse.requestedAt.getTime()
: 0; : 0;
const promptTokens = modelResponse.promptTokens; const inputTokens = modelResponse.inputTokens;
const completionTokens = modelResponse.completionTokens; const outputTokens = modelResponse.outputTokens;
return ( return (
<HStack <HStack
@@ -55,8 +55,8 @@ export const OutputStats = ({
</HStack> </HStack>
{modelResponse.cost && ( {modelResponse.cost && (
<CostTooltip <CostTooltip
promptTokens={promptTokens} inputTokens={inputTokens}
completionTokens={completionTokens} outputTokens={outputTokens}
cost={modelResponse.cost} cost={modelResponse.cost}
> >
<HStack spacing={0}> <HStack spacing={0}>

View File

@@ -1,7 +1,7 @@
import { isEqual } from "lodash-es"; import { isEqual } from "lodash-es";
import { useEffect, useState, type DragEvent } from "react"; import { useEffect, useState, type DragEvent } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperimentAccess, useHandledAsyncCallback, useScenarioVars } from "~/utils/hooks";
import { type Scenario } from "./types"; import { type Scenario } from "./types";
import { import {
@@ -41,8 +41,7 @@ export default function ScenarioEditor({
if (savedValues) setValues(savedValues); if (savedValues) setValues(savedValues);
}, [savedValues]); }, [savedValues]);
const experiment = useExperiment(); const vars = useScenarioVars();
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
const variableLabels = vars.data?.map((v) => v.label) ?? []; const variableLabels = vars.data?.map((v) => v.label) ?? [];

View File

@@ -58,7 +58,7 @@ export const ScenarioEditorModal = ({
await utils.scenarios.list.invalidate(); await utils.scenarios.list.invalidate();
}, [mutation, values]); }, [mutation, values]);
const vars = api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }); const vars = api.scenarioVars.list.useQuery({ experimentId: experiment.data?.id ?? "" });
const variableLabels = vars.data?.map((v) => v.label) ?? []; const variableLabels = vars.data?.map((v) => v.label) ?? [];
return ( return (

View File

@@ -72,7 +72,7 @@ export const ScenariosHeader = () => {
Autogenerate Scenario Autogenerate Scenario
</MenuItem> </MenuItem>
<MenuItem icon={<BsPencil />} onClick={openDrawer}> <MenuItem icon={<BsPencil />} onClick={openDrawer}>
Edit Vars Add or Remove Variables
</MenuItem> </MenuItem>
</MenuList> </MenuList>
</Menu> </Menu>

View File

@@ -17,8 +17,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
initialData: { initialData: {
evalResults: [], evalResults: [],
overallCost: 0, overallCost: 0,
promptTokens: 0, inputTokens: 0,
completionTokens: 0, outputTokens: 0,
scenarioCount: 0, scenarioCount: 0,
outputCount: 0, outputCount: 0,
awaitingEvals: false, awaitingEvals: false,
@@ -68,8 +68,8 @@ export default function VariantStats(props: { variant: PromptVariant }) {
</HStack> </HStack>
{data.overallCost && ( {data.overallCost && (
<CostTooltip <CostTooltip
promptTokens={data.promptTokens} inputTokens={data.inputTokens}
completionTokens={data.completionTokens} outputTokens={data.outputTokens}
cost={data.overallCost} cost={data.overallCost}
> >
<HStack spacing={0} align="center" color="gray.500"> <HStack spacing={0} align="center" color="gray.500">

View File

@@ -90,15 +90,23 @@ function TableRow({
isExpanded: boolean; isExpanded: boolean;
onToggle: () => void; onToggle: () => void;
}) { }) {
const isError = loggedCall.modelResponse?.respStatus !== 200; const isError = loggedCall.modelResponse?.statusCode !== 200;
const timeAgo = dayjs(loggedCall.startTime).fromNow(); const timeAgo = dayjs(loggedCall.requestedAt).fromNow();
const fullTime = dayjs(loggedCall.startTime).toString(); const fullTime = dayjs(loggedCall.requestedAt).toString();
const model = useMemo( const model = useMemo(
() => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value, () => loggedCall.tags.find((tag) => tag.name.startsWith("$model"))?.value,
[loggedCall.tags], [loggedCall.tags],
); );
const durationCell = (
<Td isNumeric>
{loggedCall.cacheHit
? "Cache hit"
: ((loggedCall.modelResponse?.durationMs ?? 0) / 1000).toFixed(2) + "s"}
</Td>
);
return ( return (
<> <>
<Tr <Tr
@@ -120,11 +128,11 @@ function TableRow({
</Tooltip> </Tooltip>
</Td> </Td>
<Td width="100%">{model}</Td> <Td width="100%">{model}</Td>
<Td isNumeric>{((loggedCall.modelResponse?.durationMs ?? 0) / 1000).toFixed(2)}s</Td> {durationCell}
<Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td> <Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td>
<Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td> <Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td>
<Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric> <Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric>
{loggedCall.modelResponse?.respStatus ?? "No response"} {loggedCall.modelResponse?.statusCode ?? "No response"}
</Td> </Td>
</Tr> </Tr>
<Tr> <Tr>

View File

@@ -0,0 +1,61 @@
import {
ResponsiveContainer,
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
} from "recharts";
import { useMemo } from "react";
import { useSelectedProject } from "~/utils/hooks";
import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api";
export default function UsageGraph() {
const { data: selectedProject } = useSelectedProject();
const stats = api.dashboard.stats.useQuery(
{ projectId: selectedProject?.id ?? "" },
{ enabled: !!selectedProject },
);
const data = useMemo(() => {
return (
stats.data?.periods.map(({ period, numQueries, cost }) => ({
period,
Requests: numQueries,
"Total Spent (USD)": parseFloat(cost.toString()),
})) || []
);
}, [stats.data]);
return (
<ResponsiveContainer width="100%" height={400}>
<LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
<XAxis dataKey="period" tickFormatter={(str: string) => dayjs(str).format("MMM D")} />
<YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
<YAxis
yAxisId="right"
dataKey="Total Spent (USD)"
orientation="right"
unit="$"
stroke="#82ca9d"
/>
<Tooltip />
<Legend />
<CartesianGrid stroke="#f5f5f5" />
<Line dataKey="Requests" stroke="#8884d8" yAxisId="left" dot={false} strokeWidth={2} />
<Line
dataKey="Total Spent (USD)"
stroke="#82ca9d"
yAxisId="right"
dot={false}
strokeWidth={2}
/>
</LineChart>
</ResponsiveContainer>
);
}

View File

@@ -72,12 +72,12 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewDatasetCard = () => { export const NewDatasetCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId); const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const createMutation = api.datasets.create.useMutation(); const createMutation = api.datasets.create.useMutation();
const [createDataset, isLoading] = useHandledAsyncCallback(async () => { const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
const newDataset = await createMutation.mutateAsync({ organizationId: selectedOrgId ?? "" }); const newDataset = await createMutation.mutateAsync({ projectId: selectedProjectId ?? "" });
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
}, [createMutation, router, selectedOrgId]); }, [createMutation, router, selectedProjectId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -76,17 +76,17 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewExperimentCard = () => { export const NewExperimentCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId); const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const createMutation = api.experiments.create.useMutation(); const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => { const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ const newExperiment = await createMutation.mutateAsync({
organizationId: selectedOrgId ?? "", projectId: selectedProjectId ?? "",
}); });
await router.push({ await router.push({
pathname: "/experiments/[id]", pathname: "/experiments/[id]",
query: { id: newExperiment.id }, query: { id: newExperiment.id },
}); });
}, [createMutation, router, selectedOrgId]); }, [createMutation, router, selectedProjectId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -10,15 +10,15 @@ export const useOnForkButtonPressed = () => {
const user = useSession().data; const user = useSession().data;
const experiment = useExperiment(); const experiment = useExperiment();
const selectedOrgId = useAppStore((state) => state.selectedOrgId); const selectedProjectId = useAppStore((state) => state.selectedProjectId);
const forkMutation = api.experiments.fork.useMutation(); const forkMutation = api.experiments.fork.useMutation();
const [onFork, isForking] = useHandledAsyncCallback(async () => { const [onFork, isForking] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id || !selectedOrgId) return; if (!experiment.data?.id || !selectedProjectId) return;
const forkedExperimentId = await forkMutation.mutateAsync({ const forkedExperimentId = await forkMutation.mutateAsync({
id: experiment.data.id, id: experiment.data.id,
organizationId: selectedOrgId, projectId: selectedProjectId,
}); });
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } }); await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
}, [forkMutation, experiment.data?.id, router]); }, [forkMutation, experiment.data?.id, router]);

View File

@@ -40,8 +40,15 @@ const NavSidebar = () => {
borderRightWidth={1} borderRightWidth={1}
borderColor="gray.300" borderColor="gray.300"
> >
<HStack as={Link} href="/" _hover={{ textDecoration: "none" }} spacing={0} px={2} py={2}> <HStack
<Image src="/logo.svg" alt="" boxSize={6} mr={4} /> as={Link}
href="/"
_hover={{ textDecoration: "none" }}
spacing={{ base: 1, md: 0 }}
mx={2}
py={{ base: 1, md: 2 }}
>
<Image src="/logo.svg" alt="" boxSize={6} mr={4} ml={{ base: 0.5, md: 0 }} />
<Heading size="md" fontFamily="inconsolata, monospace"> <Heading size="md" fontFamily="inconsolata, monospace">
OpenPipe OpenPipe
</Heading> </Heading>

View File

@@ -1,11 +1,10 @@
import { Box, type BoxProps } from "@chakra-ui/react"; import { Box, type BoxProps, forwardRef } from "@chakra-ui/react";
import { useRouter } from "next/router"; import { useRouter } from "next/router";
const NavSidebarOption = ({ const NavSidebarOption = forwardRef<
activeHrefPattern, { activeHrefPattern?: string; disableHoverEffect?: boolean } & BoxProps,
disableHoverEffect, "div"
...props >(({ activeHrefPattern, disableHoverEffect, ...props }, ref) => {
}: { activeHrefPattern?: string; disableHoverEffect?: boolean } & BoxProps) => {
const router = useRouter(); const router = useRouter();
const isActive = activeHrefPattern && router.pathname.startsWith(activeHrefPattern); const isActive = activeHrefPattern && router.pathname.startsWith(activeHrefPattern);
return ( return (
@@ -18,10 +17,13 @@ const NavSidebarOption = ({
cursor="pointer" cursor="pointer"
borderRadius={4} borderRadius={4}
{...props} {...props}
ref={ref}
> >
{props.children} {props.children}
</Box> </Box>
); );
}; });
NavSidebarOption.displayName = "NavSidebarOption";
export default NavSidebarOption; export default NavSidebarOption;

View File

@@ -1,12 +1,12 @@
import { HStack, Flex, Text } from "@chakra-ui/react"; import { HStack, Flex, Text } from "@chakra-ui/react";
import { useSelectedOrg } from "~/utils/hooks"; import { useSelectedProject } from "~/utils/hooks";
// Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't // Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't
// recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb. // recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb.
export default function ProjectBreadcrumbContents({ orgName = "" }: { orgName?: string }) { export default function ProjectBreadcrumbContents({ projectName = "" }: { projectName?: string }) {
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedProject } = useSelectedProject();
orgName = orgName || selectedOrg?.name || ""; projectName = projectName || selectedProject?.name || "";
return ( return (
<HStack w="full"> <HStack w="full">
@@ -18,10 +18,10 @@ export default function ProjectBreadcrumbContents({ orgName = "" }: { orgName?:
alignItems="center" alignItems="center"
justifyContent="center" justifyContent="center"
> >
<Text>{orgName[0]?.toUpperCase()}</Text> <Text>{projectName[0]?.toUpperCase()}</Text>
</Flex> </Flex>
<Text display={{ base: "none", md: "block" }} py={1}> <Text display={{ base: "none", md: "block" }} py={1}>
{orgName} {projectName}
</Text> </Text>
</HStack> </HStack>
); );

View File

@@ -15,41 +15,43 @@ import {
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import React, { useEffect, useState } from "react"; import React, { useEffect, useState } from "react";
import Link from "next/link"; import Link from "next/link";
import { AiFillCaretDown } from "react-icons/ai"; import { BsChevronRight, BsGear, BsPlus } from "react-icons/bs";
import { BsGear, BsPlus } from "react-icons/bs"; import { type Project } from "@prisma/client";
import { type Organization } from "@prisma/client";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import NavSidebarOption from "./NavSidebarOption"; import NavSidebarOption from "./NavSidebarOption";
import { useHandledAsyncCallback, useSelectedOrg } from "~/utils/hooks"; import { useHandledAsyncCallback, useSelectedProject } from "~/utils/hooks";
import { useRouter } from "next/router"; import { useRouter } from "next/router";
export default function ProjectMenu() { export default function ProjectMenu() {
const router = useRouter(); const router = useRouter();
const isActive = router.pathname.startsWith("/home");
const utils = api.useContext(); const utils = api.useContext();
const selectedOrgId = useAppStore((s) => s.selectedOrgId); const selectedProjectId = useAppStore((s) => s.selectedProjectId);
const setSelectedOrgId = useAppStore((s) => s.setSelectedOrgId); const setselectedProjectId = useAppStore((s) => s.setselectedProjectId);
const { data: orgs } = api.organizations.list.useQuery(); const { data: projects } = api.projects.list.useQuery();
useEffect(() => { useEffect(() => {
if (orgs && orgs[0] && (!selectedOrgId || !orgs.find((org) => org.id === selectedOrgId))) { if (
setSelectedOrgId(orgs[0].id); projects &&
projects[0] &&
(!selectedProjectId || !projects.find((proj) => proj.id === selectedProjectId))
) {
setselectedProjectId(projects[0].id);
} }
}, [selectedOrgId, setSelectedOrgId, orgs]); }, [selectedProjectId, setselectedProjectId, projects]);
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedProject } = useSelectedProject();
const popover = useDisclosure(); const popover = useDisclosure();
const createMutation = api.organizations.create.useMutation(); const createMutation = api.projects.create.useMutation();
const [createProject, isLoading] = useHandledAsyncCallback(async () => { const [createProject, isLoading] = useHandledAsyncCallback(async () => {
const newOrg = await createMutation.mutateAsync({ name: "New Project" }); const newProj = await createMutation.mutateAsync({ name: "New Project" });
await utils.organizations.list.invalidate(); await utils.projects.list.invalidate();
setSelectedOrgId(newOrg.id); setselectedProjectId(newProj.id);
await router.push({ pathname: "/project/settings" }); await router.push({ pathname: "/project/settings" });
}, [createMutation, router]); }, [createMutation, router]);
@@ -65,15 +67,10 @@ export default function ProjectMenu() {
> >
PROJECT PROJECT
</Text> </Text>
<NavSidebarOption> <Popover placement="right-end" isOpen={popover.isOpen} onClose={popover.onClose} closeOnBlur>
<Popover <PopoverTrigger>
placement="bottom-start" <NavSidebarOption>
isOpen={popover.isOpen} <HStack w="full" onClick={popover.onToggle}>
onClose={popover.onClose}
closeOnBlur
>
<PopoverTrigger>
<HStack w="full" justifyContent="space-between" onClick={popover.onToggle}>
<Flex <Flex
p={1} p={1}
borderRadius={4} borderRadius={4}
@@ -83,74 +80,68 @@ export default function ProjectMenu() {
m={{ base: 0, md: 1 }} m={{ base: 0, md: 1 }}
alignItems="center" alignItems="center"
justifyContent="center" justifyContent="center"
// onClick={sidebarExpanded ? undefined : openMenu}
> >
<Text>{selectedOrg?.name[0]?.toUpperCase()}</Text> <Text>{selectedProject?.name[0]?.toUpperCase()}</Text>
</Flex> </Flex>
<Text fontSize="sm" display={{ base: "none", md: "block" }} py={1}> <Text fontSize="sm" display={{ base: "none", md: "block" }} py={1} flex={1}>
{selectedOrg?.name} {selectedProject?.name}
</Text> </Text>
<Icon as={AiFillCaretDown} boxSize={3} size="xs" color="gray.500" mr={2} /> <Icon as={BsChevronRight} boxSize={4} color="gray.500" />
</HStack> </HStack>
</PopoverTrigger> </NavSidebarOption>
<PopoverContent </PopoverTrigger>
_focusVisible={{ boxShadow: "unset" }} <PopoverContent _focusVisible={{ outline: "unset" }} ml={-1} w="auto" minW={100} maxW={280}>
minW={0} <VStack alignItems="flex-start" spacing={2} py={4} px={2}>
borderColor="blue.400" <Text color="gray.500" fontSize="xs" fontWeight="bold" pb={1}>
w="full" PROJECTS
> </Text>
<VStack alignItems="flex-start" spacing={2} py={4} px={2}> <Divider />
<Text color="gray.500" fontSize="xs" fontWeight="bold" pb={1}> <VStack spacing={0} w="full">
PROJECTS {projects?.map((proj) => (
</Text> <ProjectOption
<Divider /> key={proj.id}
<VStack spacing={0} w="full"> proj={proj}
{orgs?.map((org) => ( isActive={proj.id === selectedProjectId}
<ProjectOption onClose={popover.onClose}
key={org.id} />
org={org} ))}
isActive={org.id === selectedOrgId}
onClose={popover.onClose}
/>
))}
</VStack>
<HStack
as={Button}
variant="ghost"
colorScheme="blue"
color="blue.400"
pr={8}
w="full"
onClick={createProject}
>
<Icon as={isLoading ? Spinner : BsPlus} boxSize={6} />
<Text>New project</Text>
</HStack>
</VStack> </VStack>
</PopoverContent> <HStack
</Popover> as={Button}
</NavSidebarOption> variant="ghost"
colorScheme="blue"
color="blue.400"
pr={8}
w="full"
onClick={createProject}
>
<Icon as={isLoading ? Spinner : BsPlus} boxSize={6} />
<Text>New project</Text>
</HStack>
</VStack>
</PopoverContent>
</Popover>
</VStack> </VStack>
); );
} }
const ProjectOption = ({ const ProjectOption = ({
org, proj,
isActive, isActive,
onClose, onClose,
}: { }: {
org: Organization; proj: Project;
isActive: boolean; isActive: boolean;
onClose: () => void; onClose: () => void;
}) => { }) => {
const setSelectedOrgId = useAppStore((s) => s.setSelectedOrgId); const setselectedProjectId = useAppStore((s) => s.setselectedProjectId);
const [gearHovered, setGearHovered] = useState(false); const [gearHovered, setGearHovered] = useState(false);
return ( return (
<HStack <HStack
as={Link} as={Link}
href="/experiments" href="/experiments"
onClick={() => { onClick={() => {
setSelectedOrgId(org.id); setselectedProjectId(proj.id);
onClose(); onClose();
}} }}
w="full" w="full"
@@ -158,12 +149,14 @@ const ProjectOption = ({
bgColor={isActive ? "gray.100" : "transparent"} bgColor={isActive ? "gray.100" : "transparent"}
_hover={gearHovered ? undefined : { bgColor: "gray.200", textDecoration: "none" }} _hover={gearHovered ? undefined : { bgColor: "gray.200", textDecoration: "none" }}
p={2} p={2}
borderRadius={4}
spacing={4}
> >
<Text>{org.name}</Text> <Text>{proj.name}</Text>
<IconButton <IconButton
as={Link} as={Link}
href="/project/settings" href="/project/settings"
aria-label={`Open ${org.name} settings`} aria-label={`Open ${proj.name} settings`}
icon={<Icon as={BsGear} boxSize={5} strokeWidth={0.5} color="gray.500" />} icon={<Icon as={BsGear} boxSize={5} strokeWidth={0.5} color="gray.500" />}
variant="ghost" variant="ghost"
size="xs" size="xs"

View File

@@ -9,7 +9,6 @@ import {
PopoverContent, PopoverContent,
Link, Link,
type StackProps, type StackProps,
Box,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type Session } from "next-auth"; import { type Session } from "next-auth";
import { signOut } from "next-auth/react"; import { signOut } from "next-auth/react";
@@ -27,30 +26,28 @@ export default function UserMenu({ user, ...rest }: { user: Session } & StackPro
<> <>
<Popover placement="right"> <Popover placement="right">
<PopoverTrigger> <PopoverTrigger>
<Box> <NavSidebarOption>
<NavSidebarOption> <HStack
<HStack // Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile
// Weird values to make mobile look right; can clean up when we make the sidebar disappear on mobile py={2}
py={2} px={1}
px={1} spacing={3}
spacing={3} {...rest}
{...rest} >
> {profileImage}
{profileImage} <VStack spacing={0} align="start" flex={1} flexShrink={1}>
<VStack spacing={0} align="start" flex={1} flexShrink={1}> <Text fontWeight="bold" fontSize="sm">
<Text fontWeight="bold" fontSize="sm"> {user.user.name}
{user.user.name} </Text>
</Text> <Text color="gray.500" fontSize="xs">
<Text color="gray.500" fontSize="xs"> {/* {user.user.email} */}
{/* {user.user.email} */} </Text>
</Text> </VStack>
</VStack> <Icon as={BsChevronRight} boxSize={4} color="gray.500" />
<Icon as={BsChevronRight} boxSize={4} color="gray.500" /> </HStack>
</HStack> </NavSidebarOption>
</NavSidebarOption>
</Box>
</PopoverTrigger> </PopoverTrigger>
<PopoverContent _focusVisible={{ boxShadow: "unset", outline: "unset" }} maxW="200px"> <PopoverContent _focusVisible={{ outline: "unset" }} ml={-1} minW={48} w="full">
<VStack align="stretch" spacing={0}> <VStack align="stretch" spacing={0}>
{/* sign out */} {/* sign out */}
<HStack <HStack

View File

@@ -16,7 +16,7 @@ import {
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import { useRef, useState } from "react"; import { useRef, useState } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback, useSelectedOrg } from "~/utils/hooks"; import { useHandledAsyncCallback, useSelectedProject } from "~/utils/hooks";
export const DeleteProjectDialog = ({ export const DeleteProjectDialog = ({
isOpen, isOpen,
@@ -25,20 +25,20 @@ export const DeleteProjectDialog = ({
isOpen: boolean; isOpen: boolean;
onClose: () => void; onClose: () => void;
}) => { }) => {
const selectedOrg = useSelectedOrg(); const selectedProject = useSelectedProject();
const deleteMutation = api.organizations.delete.useMutation(); const deleteMutation = api.projects.delete.useMutation();
const utils = api.useContext(); const utils = api.useContext();
const router = useRouter(); const router = useRouter();
const cancelRef = useRef<HTMLButtonElement>(null); const cancelRef = useRef<HTMLButtonElement>(null);
const [onDeleteConfirm, isDeleting] = useHandledAsyncCallback(async () => { const [onDeleteConfirm, isDeleting] = useHandledAsyncCallback(async () => {
if (!selectedOrg.data?.id) return; if (!selectedProject.data?.id) return;
await deleteMutation.mutateAsync({ id: selectedOrg.data.id }); await deleteMutation.mutateAsync({ id: selectedProject.data.id });
await utils.organizations.list.invalidate(); await utils.projects.list.invalidate();
await router.push({ pathname: "/experiments" }); await router.push({ pathname: "/experiments" });
onClose(); onClose();
}, [deleteMutation, selectedOrg, router]); }, [deleteMutation, selectedProject, router]);
const [nameToDelete, setNameToDelete] = useState(""); const [nameToDelete, setNameToDelete] = useState("");
@@ -58,10 +58,10 @@ export const DeleteProjectDialog = ({
of the project below. of the project below.
</Text> </Text>
<Box bgColor="orange.100" w="full" p={2} borderRadius={4}> <Box bgColor="orange.100" w="full" p={2} borderRadius={4}>
<Text fontFamily="inconsolata">{selectedOrg.data?.name}</Text> <Text fontFamily="inconsolata">{selectedProject.data?.name}</Text>
</Box> </Box>
<Input <Input
placeholder={selectedOrg.data?.name} placeholder={selectedProject.data?.name}
value={nameToDelete} value={nameToDelete}
onChange={(e) => setNameToDelete(e.target.value)} onChange={(e) => setNameToDelete(e.target.value)}
/> />
@@ -76,7 +76,7 @@ export const DeleteProjectDialog = ({
colorScheme="red" colorScheme="red"
onClick={onDeleteConfirm} onClick={onDeleteConfirm}
ml={3} ml={3}
isDisabled={nameToDelete !== selectedOrg.data?.name} isDisabled={nameToDelete !== selectedProject.data?.name}
w={20} w={20}
> >
{isDeleting ? <Spinner /> : "Delete"} {isDeleting ? <Spinner /> : "Delete"}

View File

@@ -2,14 +2,14 @@ import { HStack, Icon, Text, Tooltip, type TooltipProps, VStack, Divider } from
import { BsCurrencyDollar } from "react-icons/bs"; import { BsCurrencyDollar } from "react-icons/bs";
type CostTooltipProps = { type CostTooltipProps = {
promptTokens: number | null; inputTokens: number | null;
completionTokens: number | null; outputTokens: number | null;
cost: number; cost: number;
} & TooltipProps; } & TooltipProps;
export const CostTooltip = ({ export const CostTooltip = ({
promptTokens, inputTokens,
completionTokens, outputTokens,
cost, cost,
children, children,
...props ...props
@@ -36,12 +36,12 @@ export const CostTooltip = ({
<HStack> <HStack>
<VStack w="28" spacing={1}> <VStack w="28" spacing={1}>
<Text>Prompt</Text> <Text>Prompt</Text>
<Text>{promptTokens ?? 0}</Text> <Text>{inputTokens ?? 0}</Text>
</VStack> </VStack>
<Divider borderColor="gray.200" h={8} orientation="vertical" /> <Divider borderColor="gray.200" h={8} orientation="vertical" />
<VStack w="28" spacing={1}> <VStack w="28" spacing={1}>
<Text whiteSpace="nowrap">Completion</Text> <Text whiteSpace="nowrap">Completion</Text>
<Text>{completionTokens ?? 0}</Text> <Text>{outputTokens ?? 0}</Text>
</VStack> </VStack>
</HStack> </HStack>
</VStack> </VStack>

View File

@@ -9,7 +9,8 @@
"claude-2", "claude-2",
"claude-2.0", "claude-2.0",
"claude-instant-1", "claude-instant-1",
"claude-instant-1.1" "claude-instant-1.1",
"claude-instant-1.2"
] ]
}, },
"prompt": { "prompt": {

View File

@@ -28,6 +28,10 @@ const modelProvider: AnthropicProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
// TODO: add usage logic
return null;
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -4,14 +4,10 @@ import {
type ChatCompletion, type ChatCompletion,
type CompletionCreateParams, type CompletionCreateParams,
} from "openai/resources/chat"; } from "openai/resources/chat";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { type CompletionResponse } from "../types"; import { type CompletionResponse } from "../types";
import { isArray, isString, omit } from "lodash-es"; import { isArray, isString, omit } from "lodash-es";
import { openai } from "~/server/utils/openai"; import { openai } from "~/server/utils/openai";
import { truthyFilter } from "~/utils/utils";
import { APIError } from "openai"; import { APIError } from "openai";
import frontendModelProvider from "./frontend";
import modelProvider, { type SupportedModel } from ".";
const mergeStreamedChunks = ( const mergeStreamedChunks = (
base: ChatCompletion | null, base: ChatCompletion | null,
@@ -60,9 +56,6 @@ export async function getCompletion(
): Promise<CompletionResponse<ChatCompletion>> { ): Promise<CompletionResponse<ChatCompletion>> {
const start = Date.now(); const start = Date.now();
let finalCompletion: ChatCompletion | null = null; let finalCompletion: ChatCompletion | null = null;
let promptTokens: number | undefined = undefined;
let completionTokens: number | undefined = undefined;
const modelName = modelProvider.getModel(input) as SupportedModel;
try { try {
if (onStream) { if (onStream) {
@@ -86,16 +79,6 @@ export async function getCompletion(
autoRetry: false, autoRetry: false,
}; };
} }
try {
promptTokens = countOpenAIChatTokens(modelName, input.messages);
completionTokens = countOpenAIChatTokens(
modelName,
finalCompletion.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
// TODO handle this, library seems like maybe it doesn't work with function calls?
console.error(err);
}
} else { } else {
const resp = await openai.chat.completions.create( const resp = await openai.chat.completions.create(
{ ...input, stream: false }, { ...input, stream: false },
@@ -104,25 +87,14 @@ export async function getCompletion(
}, },
); );
finalCompletion = resp; finalCompletion = resp;
promptTokens = resp.usage?.prompt_tokens ?? 0;
completionTokens = resp.usage?.completion_tokens ?? 0;
} }
const timeToComplete = Date.now() - start; const timeToComplete = Date.now() - start;
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[modelName];
let cost = undefined;
if (promptTokenPrice && completionTokenPrice && promptTokens && completionTokens) {
cost = promptTokens * promptTokenPrice + completionTokens * completionTokenPrice;
}
return { return {
type: "success", type: "success",
statusCode: 200, statusCode: 200,
value: finalCompletion, value: finalCompletion,
timeToComplete, timeToComplete,
promptTokens,
completionTokens,
cost,
}; };
} catch (error: unknown) { } catch (error: unknown) {
if (error instanceof APIError) { if (error instanceof APIError) {

View File

@@ -4,6 +4,8 @@ import inputSchema from "./codegen/input.schema.json";
import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat"; import { type ChatCompletion, type CompletionCreateParams } from "openai/resources/chat";
import { getCompletion } from "./getCompletion"; import { getCompletion } from "./getCompletion";
import frontendModelProvider from "./frontend"; import frontendModelProvider from "./frontend";
import { countOpenAIChatTokens } from "~/utils/countTokens";
import { truthyFilter } from "~/utils/utils";
const supportedModels = [ const supportedModels = [
"gpt-4-0613", "gpt-4-0613",
@@ -39,6 +41,41 @@ const modelProvider: OpenaiChatModelProvider = {
inputSchema: inputSchema as JSONSchema4, inputSchema: inputSchema as JSONSchema4,
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
if (output.choices.length === 0) return null;
const model = modelProvider.getModel(input);
if (!model) return null;
let inputTokens: number;
let outputTokens: number;
if (output.usage) {
inputTokens = output.usage.prompt_tokens;
outputTokens = output.usage.completion_tokens;
} else {
try {
inputTokens = countOpenAIChatTokens(model, input.messages);
outputTokens = countOpenAIChatTokens(
model,
output.choices.map((c) => c.message).filter(truthyFilter),
);
} catch (err) {
inputTokens = 0;
outputTokens = 0;
// TODO handle this, library seems like maybe it doesn't work with function calls?
console.error(err);
}
}
const { promptTokenPrice, completionTokenPrice } = frontendModelProvider.models[model];
let cost = undefined;
if (promptTokenPrice && completionTokenPrice && inputTokens && outputTokens) {
cost = inputTokens * promptTokenPrice + outputTokens * completionTokenPrice;
}
return { inputTokens: inputTokens, outputTokens: outputTokens, cost };
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -75,6 +75,10 @@ const modelProvider: ReplicateLlama2Provider = {
}, },
canStream: true, canStream: true,
getCompletion, getCompletion,
getUsage: (input, output) => {
// TODO: add usage logic
return null;
},
...frontendModelProvider, ...frontendModelProvider,
}; };

View File

@@ -43,9 +43,6 @@ export type CompletionResponse<T> =
value: T; value: T;
timeToComplete: number; timeToComplete: number;
statusCode: number; statusCode: number;
promptTokens?: number;
completionTokens?: number;
cost?: number;
}; };
export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = { export type ModelProvider<SupportedModels extends string, InputSchema, OutputSchema> = {
@@ -56,6 +53,10 @@ export type ModelProvider<SupportedModels extends string, InputSchema, OutputSch
input: InputSchema, input: InputSchema,
onStream: ((partialOutput: OutputSchema) => void) | null, onStream: ((partialOutput: OutputSchema) => void) | null,
) => Promise<CompletionResponse<OutputSchema>>; ) => Promise<CompletionResponse<OutputSchema>>;
getUsage: (
input: InputSchema,
output: OutputSchema,
) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null;
// This is just a convenience for type inference, don't use it at runtime // This is just a convenience for type inference, don't use it at runtime
_outputSchema?: OutputSchema | null; _outputSchema?: OutputSchema | null;

View File

@@ -60,7 +60,7 @@ export default function Dataset() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents orgName={dataset.data?.organization?.name} /> <ProjectBreadcrumbContents projectName={dataset.data?.project?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/data"> <Link href="/data">

View File

@@ -109,7 +109,7 @@ export default function Experiment() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents orgName={experiment.data?.organization?.name} /> <ProjectBreadcrumbContents projectName={experiment.data?.project?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/experiments"> <Link href="/experiments">

View File

@@ -18,47 +18,26 @@ import {
Breadcrumb, Breadcrumb,
BreadcrumbItem, BreadcrumbItem,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
ResponsiveContainer,
} from "recharts";
import { Ban, DollarSign, Hash } from "lucide-react"; import { Ban, DollarSign, Hash } from "lucide-react";
import { useMemo } from "react";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import { useSelectedOrg } from "~/utils/hooks"; import { useSelectedProject } from "~/utils/hooks";
import dayjs from "~/utils/dayjs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import LoggedCallTable from "~/components/dashboard/LoggedCallTable"; import LoggedCallTable from "~/components/dashboard/LoggedCallTable";
import UsageGraph from "~/components/dashboard/UsageGraph";
export default function LoggedCalls() { export default function LoggedCalls() {
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedProject } = useSelectedProject();
const stats = api.dashboard.stats.useQuery( const stats = api.dashboard.stats.useQuery(
{ organizationId: selectedOrg?.id ?? "" }, { projectId: selectedProject?.id ?? "" },
{ enabled: !!selectedOrg }, { enabled: !!selectedProject },
); );
const data = useMemo(() => {
return (
stats.data?.periods.map(({ period, numQueries, totalCost }) => ({
period,
Requests: numQueries,
"Total Spent (USD)": parseFloat(totalCost.toString()),
})) || []
);
}, [stats.data]);
return ( return (
<AppShell requireAuth> <AppShell title="Logged Calls" requireAuth>
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
@@ -71,7 +50,7 @@ export default function LoggedCalls() {
</PageHeaderContainer> </PageHeaderContainer>
<VStack px={8} pt={4} alignItems="flex-start" spacing={4}> <VStack px={8} pt={4} alignItems="flex-start" spacing={4}>
<Text fontSize="2xl" fontWeight="bold"> <Text fontSize="2xl" fontWeight="bold">
{selectedOrg?.name} {selectedProject?.name}
</Text> </Text>
<Divider /> <Divider />
<VStack margin="auto" spacing={4} align="stretch" w="full"> <VStack margin="auto" spacing={4} align="stretch" w="full">
@@ -83,39 +62,7 @@ export default function LoggedCalls() {
</Heading> </Heading>
</CardHeader> </CardHeader>
<CardBody> <CardBody>
<ResponsiveContainer width="100%" height={400}> <UsageGraph />
<LineChart data={data} margin={{ top: 5, right: 20, left: 10, bottom: 5 }}>
<XAxis
dataKey="period"
tickFormatter={(str: string) => dayjs(str).format("MMM D")}
/>
<YAxis yAxisId="left" dataKey="Requests" orientation="left" stroke="#8884d8" />
<YAxis
yAxisId="right"
dataKey="Total Spent (USD)"
orientation="right"
unit="$"
stroke="#82ca9d"
/>
<Tooltip />
<Legend />
<CartesianGrid stroke="#f5f5f5" />
<Line
dataKey="Requests"
stroke="#8884d8"
yAxisId="left"
dot={false}
strokeWidth={2}
/>
<Line
dataKey="Total Spent (USD)"
stroke="#82ca9d"
yAxisId="right"
dot={false}
strokeWidth={2}
/>
</LineChart>
</ResponsiveContainer>
</CardBody> </CardBody>
</Card> </Card>
<VStack spacing="4" width="300px" align="stretch"> <VStack spacing="4" width="300px" align="stretch">
@@ -127,7 +74,7 @@ export default function LoggedCalls() {
<Icon as={DollarSign} boxSize={4} color="gray.500" /> <Icon as={DollarSign} boxSize={4} color="gray.500" />
</HStack> </HStack>
<StatNumber> <StatNumber>
${parseFloat(stats.data?.totals?.totalCost?.toString() ?? "0").toFixed(2)} ${parseFloat(stats.data?.totals?.cost?.toString() ?? "0").toFixed(3)}
</StatNumber> </StatNumber>
</Stat> </Stat>
</CardBody> </CardBody>

View File

@@ -5,7 +5,6 @@ import {
type TextProps, type TextProps,
VStack, VStack,
HStack, HStack,
Input,
Button, Button,
Divider, Divider,
Icon, Icon,
@@ -17,33 +16,39 @@ import { BsTrash } from "react-icons/bs";
import AppShell from "~/components/nav/AppShell"; import AppShell from "~/components/nav/AppShell";
import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback, useSelectedOrg } from "~/utils/hooks"; import { useHandledAsyncCallback, useSelectedProject } from "~/utils/hooks";
import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents";
import CopiableCode from "~/components/CopiableCode"; import CopiableCode from "~/components/CopiableCode";
import { DeleteProjectDialog } from "~/components/projectSettings/DeleteProjectDialog"; import { DeleteProjectDialog } from "~/components/projectSettings/DeleteProjectDialog";
import AutoResizeTextArea from "~/components/AutoResizeTextArea";
export default function Settings() { export default function Settings() {
const utils = api.useContext(); const utils = api.useContext();
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedProject } = useSelectedProject();
const apiKey = const apiKey =
selectedOrg?.apiKeys?.length && selectedOrg?.apiKeys[0] ? selectedOrg?.apiKeys[0].apiKey : ""; selectedProject?.apiKeys?.length && selectedProject?.apiKeys[0]
? selectedProject?.apiKeys[0].apiKey
: "";
const updateMutation = api.organizations.update.useMutation(); const updateMutation = api.projects.update.useMutation();
const [onSaveName] = useHandledAsyncCallback(async () => { const [onSaveName] = useHandledAsyncCallback(async () => {
if (name && name !== selectedOrg?.name && selectedOrg?.id) { if (name && name !== selectedProject?.name && selectedProject?.id) {
await updateMutation.mutateAsync({ await updateMutation.mutateAsync({
id: selectedOrg.id, id: selectedProject.id,
updates: { name }, updates: { name },
}); });
await Promise.all([utils.organizations.get.invalidate({ id: selectedOrg.id })]); await Promise.all([
utils.projects.get.invalidate({ id: selectedProject.id }),
utils.projects.list.invalidate(),
]);
} }
}, [updateMutation, selectedOrg]); }, [updateMutation, selectedProject]);
const [name, setName] = useState(selectedOrg?.name); const [name, setName] = useState(selectedProject?.name);
useEffect(() => { useEffect(() => {
setName(selectedOrg?.name); setName(selectedProject?.name);
}, [selectedOrg?.name]); }, [selectedProject?.name]);
const deleteProjectOpen = useDisclosure(); const deleteProjectOpen = useDisclosure();
@@ -66,7 +71,7 @@ export default function Settings() {
Project Settings Project Settings
</Text> </Text>
<Text fontSize="sm"> <Text fontSize="sm">
Configure your project settings. These settings only apply to {selectedOrg?.name}. Configure your project settings. These settings only apply to {selectedProject?.name}.
</Text> </Text>
</VStack> </VStack>
<VStack <VStack
@@ -82,7 +87,7 @@ export default function Settings() {
<Text fontWeight="bold" fontSize="xl"> <Text fontWeight="bold" fontSize="xl">
Display Name Display Name
</Text> </Text>
<Input <AutoResizeTextArea
w="full" w="full"
maxW={600} maxW={600}
value={name} value={name}
@@ -90,7 +95,7 @@ export default function Settings() {
borderColor="gray.300" borderColor="gray.300"
/> />
<Button <Button
isDisabled={!name || name === selectedOrg?.name} isDisabled={!name || name === selectedProject?.name}
colorScheme="orange" colorScheme="orange"
borderRadius={4} borderRadius={4}
mt={2} mt={2}
@@ -113,12 +118,12 @@ export default function Settings() {
</VStack> </VStack>
<CopiableCode code={apiKey} /> <CopiableCode code={apiKey} />
<Divider /> <Divider />
{selectedOrg?.personalOrgUserId ? ( {selectedProject?.personalProjectUserId ? (
<VStack alignItems="flex-start"> <VStack alignItems="flex-start">
<Subtitle>Personal Project</Subtitle> <Subtitle>Personal Project</Subtitle>
<Text fontSize="sm"> <Text fontSize="sm">
This project is {selectedOrg?.personalOrgUser?.name}'s personal project. It cannot This project is {selectedProject?.personalProjectUser?.name}'s personal project.
be deleted. It cannot be deleted.
</Text> </Text>
</VStack> </VStack>
) : ( ) : (
@@ -129,15 +134,18 @@ export default function Settings() {
</Text> </Text>
<HStack <HStack
as={Button} as={Button}
isDisabled={selectedOrg?.role !== "ADMIN"} isDisabled={selectedProject?.role !== "ADMIN"}
colorScheme="red" colorScheme="red"
variant="outline" variant="outline"
borderRadius={4} borderRadius={4}
mt={2} mt={2}
height="auto"
onClick={deleteProjectOpen.onOpen} onClick={deleteProjectOpen.onOpen}
> >
<Icon as={BsTrash} /> <Icon as={BsTrash} />
<Text>Delete {selectedOrg?.name}</Text> <Text overflowWrap="break-word" whiteSpace="normal" py={2}>
Delete {selectedProject?.name}
</Text>
</HStack> </HStack>
</VStack> </VStack>
)} )}

View File

@@ -3,13 +3,13 @@ import { createTRPCRouter } from "~/server/api/trpc";
import { experimentsRouter } from "./routers/experiments.router"; import { experimentsRouter } from "./routers/experiments.router";
import { scenariosRouter } from "./routers/scenarios.router"; import { scenariosRouter } from "./routers/scenarios.router";
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router"; import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
import { templateVarsRouter } from "./routers/templateVariables.router"; import { scenarioVarsRouter } from "./routers/scenarioVariables.router";
import { evaluationsRouter } from "./routers/evaluations.router"; import { evaluationsRouter } from "./routers/evaluations.router";
import { worldChampsRouter } from "./routers/worldChamps.router"; import { worldChampsRouter } from "./routers/worldChamps.router";
import { datasetsRouter } from "./routers/datasets.router"; import { datasetsRouter } from "./routers/datasets.router";
import { datasetEntries } from "./routers/datasetEntries.router"; import { datasetEntries } from "./routers/datasetEntries.router";
import { externalApiRouter } from "./routers/externalApi.router"; import { externalApiRouter } from "./routers/externalApi.router";
import { organizationsRouter } from "./routers/organizations.router"; import { projectsRouter } from "./routers/projects.router";
import { dashboardRouter } from "./routers/dashboard.router"; import { dashboardRouter } from "./routers/dashboard.router";
/** /**
@@ -22,12 +22,12 @@ export const appRouter = createTRPCRouter({
experiments: experimentsRouter, experiments: experimentsRouter,
scenarios: scenariosRouter, scenarios: scenariosRouter,
scenarioVariantCells: scenarioVariantCellsRouter, scenarioVariantCells: scenarioVariantCellsRouter,
templateVars: templateVarsRouter, scenarioVars: scenarioVarsRouter,
evaluations: evaluationsRouter, evaluations: evaluationsRouter,
worldChamps: worldChampsRouter, worldChamps: worldChampsRouter,
datasets: datasetsRouter, datasets: datasetsRouter,
datasetEntries: datasetEntries, datasetEntries: datasetEntries,
organizations: organizationsRouter, projects: projectsRouter,
dashboard: dashboardRouter, dashboard: dashboardRouter,
externalApi: externalApiRouter, externalApi: externalApiRouter,
}); });

View File

@@ -10,7 +10,7 @@ export const dashboardRouter = createTRPCRouter({
z.object({ z.object({
// TODO: actually take startDate into account // TODO: actually take startDate into account
startDate: z.string().optional(), startDate: z.string().optional(),
organizationId: z.string(), projectId: z.string(),
}), }),
) )
.query(async ({ input }) => { .query(async ({ input }) => {
@@ -22,11 +22,11 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id", "LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId", "LoggedCallModelResponse.originalLoggedCallId",
) )
.where("organizationId", "=", input.organizationId) .where("projectId", "=", input.projectId)
.select(({ fn }) => [ .select(({ fn }) => [
sql<Date>`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"), sql<Date>`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
sql<number>`count("LoggedCall"."id")::int`.as("numQueries"), sql<number>`count("LoggedCall"."id")::int`.as("numQueries"),
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"), fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
]) ])
.groupBy("period") .groupBy("period")
.orderBy("period") .orderBy("period")
@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
backfilledPeriods.unshift({ backfilledPeriods.unshift({
period: dayjs(dayToMatch).toDate(), period: dayjs(dayToMatch).toDate(),
numQueries: 0, numQueries: 0,
totalCost: 0, cost: 0,
}); });
} }
dayToMatch = dayToMatch.subtract(1, "day"); dayToMatch = dayToMatch.subtract(1, "day");
@@ -70,23 +70,23 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id", "LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId", "LoggedCallModelResponse.originalLoggedCallId",
) )
.where("organizationId", "=", input.organizationId) .where("projectId", "=", input.projectId)
.select(({ fn }) => [ .select(({ fn }) => [
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"), fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
fn.count("LoggedCall.id").as("numQueries"), fn.count("LoggedCall.id").as("numQueries"),
]) ])
.executeTakeFirst(); .executeTakeFirst();
const errors = await kysely const errors = await kysely
.selectFrom("LoggedCall") .selectFrom("LoggedCall")
.where("organizationId", "=", input.organizationId) .where("projectId", "=", input.projectId)
.leftJoin( .leftJoin(
"LoggedCallModelResponse", "LoggedCallModelResponse",
"LoggedCall.id", "LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId", "LoggedCallModelResponse.originalLoggedCallId",
) )
.select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"]) .select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
.where("respStatus", ">", 200) .where("statusCode", ">", 200)
.groupBy("code") .groupBy("code")
.orderBy("count", "desc") .orderBy("count", "desc")
.execute(); .execute();
@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
// https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758 // https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => { loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
const loggedCalls = await prisma.loggedCall.findMany({ const loggedCalls = await prisma.loggedCall.findMany({
orderBy: { startTime: "desc" }, orderBy: { requestedAt: "desc" },
include: { tags: true, modelResponse: true }, include: { tags: true, modelResponse: true },
take: 20, take: 20,
}); });

View File

@@ -3,20 +3,20 @@ import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { import {
requireCanModifyDataset, requireCanModifyDataset,
requireCanModifyOrganization, requireCanModifyProject,
requireCanViewDataset, requireCanViewDataset,
requireCanViewOrganization, requireCanViewProject,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
export const datasetsRouter = createTRPCRouter({ export const datasetsRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => { .query(async ({ input, ctx }) => {
await requireCanViewOrganization(input.organizationId, ctx); await requireCanViewProject(input.projectId, ctx);
const datasets = await prisma.dataset.findMany({ const datasets = await prisma.dataset.findMany({
where: { where: {
organizationId: input.organizationId, projectId: input.projectId,
}, },
orderBy: { orderBy: {
createdAt: "desc", createdAt: "desc",
@@ -36,26 +36,26 @@ export const datasetsRouter = createTRPCRouter({
return await prisma.dataset.findFirstOrThrow({ return await prisma.dataset.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: { include: {
organization: true, project: true,
}, },
}); });
}), }),
create: protectedProcedure create: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ projectId: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx); await requireCanModifyProject(input.projectId, ctx);
const numDatasets = await prisma.dataset.count({ const numDatasets = await prisma.dataset.count({
where: { where: {
organizationId: input.organizationId, projectId: input.projectId,
}, },
}); });
return await prisma.dataset.create({ return await prisma.dataset.create({
data: { data: {
name: `Dataset ${numDatasets + 1}`, name: `Dataset ${numDatasets + 1}`,
organizationId: input.organizationId, projectId: input.projectId,
}, },
}); });
}), }),

View File

@@ -8,9 +8,9 @@ import { generateNewCell } from "~/server/utils/generateNewCell";
import { import {
canModifyExperiment, canModifyExperiment,
requireCanModifyExperiment, requireCanModifyExperiment,
requireCanModifyOrganization, requireCanModifyProject,
requireCanViewExperiment, requireCanViewExperiment,
requireCanViewOrganization, requireCanViewProject,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
import generateTypes from "~/modelProviders/generateTypes"; import generateTypes from "~/modelProviders/generateTypes";
import { promptConstructorVersion } from "~/promptConstructor/version"; import { promptConstructorVersion } from "~/promptConstructor/version";
@@ -44,13 +44,13 @@ export const experimentsRouter = createTRPCRouter({
}; };
}), }),
list: protectedProcedure list: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ projectId: z.string() }))
.query(async ({ input, ctx }) => { .query(async ({ input, ctx }) => {
await requireCanViewOrganization(input.organizationId, ctx); await requireCanViewProject(input.projectId, ctx);
const experiments = await prisma.experiment.findMany({ const experiments = await prisma.experiment.findMany({
where: { where: {
organizationId: input.organizationId, projectId: input.projectId,
}, },
orderBy: { orderBy: {
sortIndex: "desc", sortIndex: "desc",
@@ -90,7 +90,7 @@ export const experimentsRouter = createTRPCRouter({
const experiment = await prisma.experiment.findFirstOrThrow({ const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: { include: {
organization: true, project: true,
}, },
}); });
@@ -108,10 +108,10 @@ export const experimentsRouter = createTRPCRouter({
}), }),
fork: protectedProcedure fork: protectedProcedure
.input(z.object({ id: z.string(), organizationId: z.string() })) .input(z.object({ id: z.string(), projectId: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx); await requireCanViewExperiment(input.id, ctx);
await requireCanModifyOrganization(input.organizationId, ctx); await requireCanModifyProject(input.projectId, ctx);
const [ const [
existingExp, existingExp,
@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
...modelResponseData, ...modelResponseData,
id: newModelResponseId, id: newModelResponseId,
scenarioVariantCellId: newCellId, scenarioVariantCellId: newCellId,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined, respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
}); });
for (const evaluation of outputEvaluations) { for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({ outputEvaluationsToCreate.push({
@@ -264,7 +264,7 @@ export const experimentsRouter = createTRPCRouter({
id: newExperimentId, id: newExperimentId,
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`, label: `${existingExp.label} (forked)`,
organizationId: input.organizationId, projectId: input.projectId,
}, },
}), }),
prisma.promptVariant.createMany({ prisma.promptVariant.createMany({
@@ -294,9 +294,9 @@ export const experimentsRouter = createTRPCRouter({
}), }),
create: protectedProcedure create: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ projectId: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx); await requireCanModifyProject(input.projectId, ctx);
const maxSortIndex = const maxSortIndex =
( (
@@ -304,7 +304,7 @@ export const experimentsRouter = createTRPCRouter({
_max: { _max: {
sortIndex: true, sortIndex: true,
}, },
where: { organizationId: input.organizationId }, where: { projectId: input.projectId },
}) })
)._max?.sortIndex ?? 0; )._max?.sortIndex ?? 0;
@@ -312,7 +312,7 @@ export const experimentsRouter = createTRPCRouter({
data: { data: {
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`, label: `Experiment ${maxSortIndex + 1}`,
organizationId: input.organizationId, projectId: input.projectId,
}, },
}); });

View File

@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject"; import { hashRequest } from "~/server/utils/hashObject";
import modelProvider from "~/modelProviders/openai-ChatCompletion";
import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat/completions";
const reqValidator = z.object({ const reqValidator = z.object({
model: z.string(), model: z.string(),
@@ -16,11 +21,6 @@ const reqValidator = z.object({
const respValidator = z.object({ const respValidator = z.object({
id: z.string(), id: z.string(),
model: z.string(), model: z.string(),
usage: z.object({
total_tokens: z.number(),
prompt_tokens: z.number(),
completion_tokens: z.number(),
}),
choices: z.array( choices: z.array(
z.object({ z.object({
finish_reason: z.string(), finish_reason: z.string(),
@@ -35,11 +35,12 @@ export const externalApiRouter = createTRPCRouter({
method: "POST", method: "POST",
path: "/v1/check-cache", path: "/v1/check-cache",
description: "Check if a prompt is cached", description: "Check if a prompt is cached",
protect: true,
}, },
}) })
.input( .input(
z.object({ z.object({
startTime: z.number().describe("Unix timestamp in milliseconds"), requestedAt: z.number().describe("Unix timestamp in milliseconds"),
reqPayload: z.unknown().describe("JSON-encoded request payload"), reqPayload: z.unknown().describe("JSON-encoded request payload"),
tags: z tags: z
.record(z.string()) .record(z.string())
@@ -66,26 +67,20 @@ export const externalApiRouter = createTRPCRouter({
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
const reqPayload = await reqValidator.spa(input.reqPayload); const reqPayload = await reqValidator.spa(input.reqPayload);
const cacheKey = hashRequest(key.organizationId, reqPayload as JsonValue); const cacheKey = hashRequest(key.projectId, reqPayload as JsonValue);
const existingResponse = await prisma.loggedCallModelResponse.findFirst({ const existingResponse = await prisma.loggedCallModelResponse.findFirst({
where: { where: { cacheKey },
cacheKey, include: { originalLoggedCall: true },
}, orderBy: { requestedAt: "desc" },
include: {
originalLoggedCall: true,
},
orderBy: {
startTime: "desc",
},
}); });
if (!existingResponse) return { respPayload: null }; if (!existingResponse) return { respPayload: null };
await prisma.loggedCall.create({ await prisma.loggedCall.create({
data: { data: {
organizationId: key.organizationId, projectId: key.projectId,
startTime: new Date(input.startTime), requestedAt: new Date(input.requestedAt),
cacheHit: true, cacheHit: true,
modelResponseId: existingResponse.id, modelResponseId: existingResponse.id,
}, },
@@ -102,16 +97,17 @@ export const externalApiRouter = createTRPCRouter({
method: "POST", method: "POST",
path: "/v1/report", path: "/v1/report",
description: "Report an API call", description: "Report an API call",
protect: true,
}, },
}) })
.input( .input(
z.object({ z.object({
startTime: z.number().describe("Unix timestamp in milliseconds"), requestedAt: z.number().describe("Unix timestamp in milliseconds"),
endTime: z.number().describe("Unix timestamp in milliseconds"), receivedAt: z.number().describe("Unix timestamp in milliseconds"),
reqPayload: z.unknown().describe("JSON-encoded request payload"), reqPayload: z.unknown().describe("JSON-encoded request payload"),
respPayload: z.unknown().optional().describe("JSON-encoded response payload"), respPayload: z.unknown().optional().describe("JSON-encoded response payload"),
respStatus: z.number().optional().describe("HTTP status code of response"), statusCode: z.number().optional().describe("HTTP status code of response"),
error: z.string().optional().describe("User-friendly error message"), errorMessage: z.string().optional().describe("User-friendly error message"),
tags: z tags: z
.record(z.string()) .record(z.string())
.optional() .optional()
@@ -122,6 +118,7 @@ export const externalApiRouter = createTRPCRouter({
) )
.output(z.void()) .output(z.void())
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
console.log("GOT TAGS", input.tags);
const apiKey = ctx.apiKey; const apiKey = ctx.apiKey;
if (!apiKey) { if (!apiKey) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -135,19 +132,25 @@ export const externalApiRouter = createTRPCRouter({
const reqPayload = await reqValidator.spa(input.reqPayload); const reqPayload = await reqValidator.spa(input.reqPayload);
const respPayload = await respValidator.spa(input.respPayload); const respPayload = await respValidator.spa(input.respPayload);
const requestHash = hashRequest(key.organizationId, reqPayload as JsonValue); const requestHash = hashRequest(key.projectId, reqPayload as JsonValue);
const newLoggedCallId = uuidv4(); const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4(); const newModelResponseId = uuidv4();
const usage = respPayload.success ? respPayload.data.usage : undefined; let usage;
if (reqPayload.success && respPayload.success) {
usage = modelProvider.getUsage(
input.reqPayload as CompletionCreateParams,
input.respPayload as ChatCompletion,
);
}
await prisma.$transaction([ await prisma.$transaction([
prisma.loggedCall.create({ prisma.loggedCall.create({
data: { data: {
id: newLoggedCallId, id: newLoggedCallId,
organizationId: key.organizationId, projectId: key.projectId,
startTime: new Date(input.startTime), requestedAt: new Date(input.requestedAt),
cacheHit: false, cacheHit: false,
}, },
}), }),
@@ -155,20 +158,17 @@ export const externalApiRouter = createTRPCRouter({
data: { data: {
id: newModelResponseId, id: newModelResponseId,
originalLoggedCallId: newLoggedCallId, originalLoggedCallId: newLoggedCallId,
startTime: new Date(input.startTime), requestedAt: new Date(input.requestedAt),
endTime: new Date(input.endTime), receivedAt: new Date(input.receivedAt),
reqPayload: input.reqPayload as Prisma.InputJsonValue, reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue, respPayload: input.respPayload as Prisma.InputJsonValue,
respStatus: input.respStatus, statusCode: input.statusCode,
error: input.error, errorMessage: input.errorMessage,
durationMs: input.endTime - input.startTime, durationMs: input.receivedAt - input.requestedAt,
...(respPayload.success cacheKey: respPayload.success ? requestHash : null,
? { inputTokens: usage?.inputTokens,
cacheKey: requestHash, outputTokens: usage?.outputTokens,
inputTokens: usage ? usage.prompt_tokens : undefined, cost: usage?.cost,
outputTokens: usage ? usage.completion_tokens : undefined,
}
: null),
}, },
}), }),
// Avoid foreign key constraint error by updating the logged call after the model response is created // Avoid foreign key constraint error by updating the logged call after the model response is created
@@ -182,24 +182,22 @@ export const externalApiRouter = createTRPCRouter({
}), }),
]); ]);
if (input.tags) { const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({ loggedCallId: newLoggedCallId,
loggedCallId: newLoggedCallId, // sanitize tags
// sanitize tags name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"), value,
value, }));
}));
if (reqPayload.success) { if (reqPayload.success) {
tagsToCreate.push({ tagsToCreate.push({
loggedCallId: newLoggedCallId, loggedCallId: newLoggedCallId,
name: "$model", name: "$model",
value: reqPayload.data.model, value: reqPayload.data.model,
});
}
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
}); });
} }
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
});
}), }),
}); });

View File

@@ -5,15 +5,15 @@ import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { generateApiKey } from "~/server/utils/generateApiKey"; import { generateApiKey } from "~/server/utils/generateApiKey";
import userOrg from "~/server/utils/userOrg"; import userProject from "~/server/utils/userProject";
import { import {
requireCanModifyOrganization, requireCanModifyProject,
requireCanViewOrganization, requireCanViewProject,
requireIsOrgAdmin, requireIsProjectAdmin,
requireNothing, requireNothing,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
export const organizationsRouter = createTRPCRouter({ export const projectsRouter = createTRPCRouter({
list: protectedProcedure.query(async ({ ctx }) => { list: protectedProcedure.query(async ({ ctx }) => {
const userId = ctx.session.user.id; const userId = ctx.session.user.id;
requireNothing(ctx); requireNothing(ctx);
@@ -22,9 +22,9 @@ export const organizationsRouter = createTRPCRouter({
return null; return null;
} }
const organizations = await prisma.organization.findMany({ const projects = await prisma.project.findMany({
where: { where: {
organizationUsers: { projectUsers: {
some: { userId: ctx.session.user.id }, some: { userId: ctx.session.user.id },
}, },
}, },
@@ -33,30 +33,30 @@ export const organizationsRouter = createTRPCRouter({
}, },
}); });
if (!organizations.length) { if (!projects.length) {
// TODO: We should move this to a separate endpoint that is called on sign up // TODO: We should move this to a separate endpoint that is called on sign up
const personalOrg = await userOrg(userId); const personalProject = await userProject(userId);
organizations.push(personalOrg); projects.push(personalProject);
} }
return organizations; return projects;
}), }),
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
await requireCanViewOrganization(input.id, ctx); await requireCanViewProject(input.id, ctx);
const [org, userRole] = await prisma.$transaction([ const [proj, userRole] = await prisma.$transaction([
prisma.organization.findUnique({ prisma.project.findUnique({
where: { where: {
id: input.id, id: input.id,
}, },
include: { include: {
apiKeys: true, apiKeys: true,
personalOrgUser: true, personalProjectUser: true,
}, },
}), }),
prisma.organizationUser.findFirst({ prisma.projectUser.findFirst({
where: { where: {
userId: ctx.session.user.id, userId: ctx.session.user.id,
organizationId: input.id, projectId: input.id,
role: { role: {
in: ["ADMIN", "MEMBER"], in: ["ADMIN", "MEMBER"],
}, },
@@ -64,20 +64,20 @@ export const organizationsRouter = createTRPCRouter({
}), }),
]); ]);
if (!org) { if (!proj) {
throw new TRPCError({ code: "NOT_FOUND" }); throw new TRPCError({ code: "NOT_FOUND" });
} }
return { return {
...org, ...proj,
role: userRole?.role ?? null, role: userRole?.role ?? null,
}; };
}), }),
update: protectedProcedure update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) })) .input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.id, ctx); await requireCanModifyProject(input.id, ctx);
return await prisma.organization.update({ return await prisma.project.update({
where: { where: {
id: input.id, id: input.id,
}, },
@@ -90,36 +90,36 @@ export const organizationsRouter = createTRPCRouter({
.input(z.object({ name: z.string() })) .input(z.object({ name: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
requireNothing(ctx); requireNothing(ctx);
const newOrgId = uuidv4(); const newProjectId = uuidv4();
const [newOrg] = await prisma.$transaction([ const [newProject] = await prisma.$transaction([
prisma.organization.create({ prisma.project.create({
data: { data: {
id: newOrgId, id: newProjectId,
name: input.name, name: input.name,
}, },
}), }),
prisma.organizationUser.create({ prisma.projectUser.create({
data: { data: {
userId: ctx.session.user.id, userId: ctx.session.user.id,
organizationId: newOrgId, projectId: newProjectId,
role: "ADMIN", role: "ADMIN",
}, },
}), }),
prisma.apiKey.create({ prisma.apiKey.create({
data: { data: {
name: "Default API Key", name: "Default API Key",
organizationId: newOrgId, projectId: newProjectId,
apiKey: generateApiKey(), apiKey: generateApiKey(),
}, },
}), }),
]); ]);
return newOrg; return newProject;
}), }),
delete: protectedProcedure delete: protectedProcedure
.input(z.object({ id: z.string() })) .input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
await requireIsOrgAdmin(input.id, ctx); await requireIsProjectAdmin(input.id, ctx);
return await prisma.organization.delete({ return await prisma.project.delete({
where: { where: {
id: input.id, id: input.id,
}, },

View File

@@ -3,7 +3,7 @@ import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { Prisma } from "@prisma/client"; import { Prisma } from "@prisma/client";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
import userError from "~/server/utils/error"; import { error, success } from "~/utils/errorHandling/standardResponses";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants"; import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
import { type PromptVariant } from "@prisma/client"; import { type PromptVariant } from "@prisma/client";
@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
where: { where: {
modelResponse: { modelResponse: {
outdated: false, outdated: false,
output: { not: Prisma.AnyNull }, respPayload: { not: Prisma.AnyNull },
scenarioVariantCell: { scenarioVariantCell: {
promptVariant: { promptVariant: {
id: input.variantId, id: input.variantId,
@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
modelResponses: { modelResponses: {
some: { some: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
}, },
@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallTokens = await prisma.modelResponse.aggregate({ const overallTokens = await prisma.modelResponse.aggregate({
where: { where: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
scenarioVariantCell: { scenarioVariantCell: {
@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
}, },
_sum: { _sum: {
cost: true, cost: true,
promptTokens: true, inputTokens: true,
completionTokens: true, outputTokens: true,
}, },
}); });
const promptTokens = overallTokens._sum?.promptTokens ?? 0; const inputTokens = overallTokens._sum?.inputTokens ?? 0;
const completionTokens = overallTokens._sum?.completionTokens ?? 0; const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingEvals = !!evalResults.find( const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length, (result) => result.totalCount < scenarioCount * evals.length,
@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
return { return {
evalResults, evalResults,
promptTokens, inputTokens,
completionTokens, outputTokens,
overallCost: overallTokens._sum?.cost ?? 0, overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount, scenarioCount,
outputCount, outputCount,
@@ -315,7 +315,7 @@ export const promptVariantsRouter = createTRPCRouter({
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor); const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
if ("error" in constructedPrompt) { if ("error" in constructedPrompt) {
return userError(constructedPrompt.error); return error(constructedPrompt.error);
} }
const model = input.newModel const model = input.newModel
@@ -353,7 +353,7 @@ export const promptVariantsRouter = createTRPCRouter({
const parsedPrompt = await parsePromptConstructor(input.promptConstructor); const parsedPrompt = await parsePromptConstructor(input.promptConstructor);
if ("error" in parsedPrompt) { if ("error" in parsedPrompt) {
return userError(parsedPrompt.error); return error(parsedPrompt.error);
} }
// Create a duplicate with only the config changed // Create a duplicate with only the config changed
@@ -398,7 +398,7 @@ export const promptVariantsRouter = createTRPCRouter({
}); });
} }
return { status: "ok" } as const; return success();
}), }),
reorder: protectedProcedure reorder: protectedProcedure

View File

@@ -0,0 +1,143 @@
import { type TemplateVariable } from "@prisma/client";
import { sql } from "kysely";
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { kysely, prisma } from "~/server/db";
import { error, success } from "~/utils/errorHandling/standardResponses";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const scenarioVarsRouter = createTRPCRouter({
create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
// Make sure there isn't an existing variable with the same name
const existingVariable = await prisma.templateVariable.findFirst({
where: {
experimentId: input.experimentId,
label: input.label,
},
});
if (existingVariable) {
return error(`A variable named ${input.label} already exists.`);
}
await prisma.templateVariable.create({
data: {
experimentId: input.experimentId,
label: input.label,
},
});
return success();
}),
rename: protectedProcedure
.input(z.object({ id: z.string(), label: z.string() }))
.mutation(async ({ input, ctx }) => {
const templateVariable = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(templateVariable.experimentId, ctx);
// Make sure there isn't an existing variable with the same name
const existingVariable = await prisma.templateVariable.findFirst({
where: {
experimentId: templateVariable.experimentId,
label: input.label,
},
});
if (existingVariable) {
return error(`A variable named ${input.label} already exists.`);
}
await renameTemplateVariable(templateVariable, input.label);
return success();
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
});
export const renameTemplateVariable = async (
templateVariable: TemplateVariable,
newLabel: string,
) => {
const { experimentId } = templateVariable;
await kysely.transaction().execute(async (trx) => {
await trx
.updateTable("TemplateVariable")
.set({
label: newLabel,
})
.where("id", "=", templateVariable.id)
.execute();
await sql`
CREATE TEMP TABLE "TempTestScenario" AS
SELECT *
FROM "TestScenario"
WHERE "experimentId" = ${experimentId}
-- Only copy the rows that actually have a value for the variable, no reason to churn the rest and simplifies the update.
AND "variableValues"->${templateVariable.label} IS NOT NULL
`.execute(trx);
await sql`
UPDATE "TempTestScenario"
SET "variableValues" = jsonb_set(
"variableValues",
${`{${newLabel}}`},
"variableValues"->${templateVariable.label}
) - ${templateVariable.label},
"updatedAt" = NOW(),
"id" = uuid_generate_v4()
`.execute(trx);
// Print the contents of the temp table
const results = await sql`SELECT * FROM "TempTestScenario"`.execute(trx);
console.log(results.rows);
await trx
.updateTable("TestScenario")
.set({
visible: false,
})
.where("experimentId", "=", experimentId)
.execute();
await sql`
INSERT INTO "TestScenario" (id, "variableValues", "uiId", visible, "sortIndex", "experimentId", "createdAt", "updatedAt")
SELECT * FROM "TempTestScenario";
`.execute(trx);
});
};

View File

@@ -0,0 +1,110 @@
import { expect, it } from "vitest";
import { prisma } from "~/server/db";
import { renameTemplateVariable } from "./scenarioVariables.router";
const createExperiment = async () => {
return await prisma.experiment.create({
data: {
label: "Test Experiment",
project: {
create: {},
},
},
});
};
const createTemplateVar = async (experimentId: string, label: string) => {
return await prisma.templateVariable.create({
data: {
experimentId,
label,
},
});
};
it("renames templateVariables", async () => {
// Create experiments concurrently
const [exp1, exp2] = await Promise.all([createExperiment(), createExperiment()]);
// Create template variables concurrently
const [exp1Var, exp2Var1, exp2Var2] = await Promise.all([
createTemplateVar(exp1.id, "input1"),
createTemplateVar(exp2.id, "input1"),
createTemplateVar(exp2.id, "input2"),
]);
// Create test scenarios concurrently
const [exp1Scenario, exp2Scenario, exp2HiddenScenario] = await Promise.all([
prisma.testScenario.create({
data: {
experimentId: exp1.id,
visible: true,
variableValues: { input1: "test" },
},
}),
prisma.testScenario.create({
data: {
experimentId: exp2.id,
visible: true,
variableValues: { input1: "test1", otherInput: "otherTest" },
},
}),
prisma.testScenario.create({
data: {
experimentId: exp2.id,
visible: false,
variableValues: { otherInput: "otherTest2" },
},
}),
]);
await renameTemplateVariable(exp2Var1, "input1-renamed");
expect(await prisma.templateVariable.findUnique({ where: { id: exp2Var1.id } })).toMatchObject({
label: "input1-renamed",
});
// It shouldn't mess with unrelated experiments
expect(await prisma.testScenario.findUnique({ where: { id: exp1Scenario.id } })).toMatchObject({
visible: true,
variableValues: { input1: "test" },
});
// Make sure there are a total of 4 scenarios for exp2
expect(
await prisma.testScenario.count({
where: {
experimentId: exp2.id,
},
}),
).toBe(3);
// It shouldn't mess with the existing scenarios, except to hide them
expect(await prisma.testScenario.findUnique({ where: { id: exp2Scenario.id } })).toMatchObject({
visible: false,
variableValues: { input1: "test1", otherInput: "otherTest" },
});
// It should create a new scenario with the new variable name
const newScenario1 = await prisma.testScenario.findFirst({
where: {
experimentId: exp2.id,
variableValues: { equals: { "input1-renamed": "test1", otherInput: "otherTest" } },
},
});
expect(newScenario1).toMatchObject({
visible: true,
});
const newScenario2 = await prisma.testScenario.findFirst({
where: {
experimentId: exp2.id,
variableValues: { equals: { otherInput: "otherTest2" } },
},
});
expect(newScenario2).toMatchObject({
visible: false,
});
});

View File

@@ -1,49 +0,0 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
export const templateVarsRouter = createTRPCRouter({
create: protectedProcedure
.input(z.object({ experimentId: z.string(), label: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanModifyExperiment(input.experimentId, ctx);
await prisma.templateVariable.create({
data: {
experimentId: input.experimentId,
label: input.label,
},
});
}),
delete: protectedProcedure
.input(z.object({ id: z.string() }))
.mutation(async ({ input, ctx }) => {
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
where: { id: input.id },
});
await requireCanModifyExperiment(experimentId, ctx);
await prisma.templateVariable.delete({ where: { id: input.id } });
}),
list: publicProcedure
.input(z.object({ experimentId: z.string() }))
.query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.experimentId, ctx);
return await prisma.templateVariable.findMany({
where: {
experimentId: input.experimentId,
},
orderBy: {
createdAt: "asc",
},
select: {
id: true,
label: true,
},
});
}),
});

View File

@@ -64,7 +64,7 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
// Get the session from the server using the getServerSession wrapper function // Get the session from the server using the getServerSession wrapper function
const session = await getServerAuthSession({ req, res }); const session = await getServerAuthSession({ req, res });
const apiKey = req.headers["x-openpipe-api-key"] as string | null; const apiKey = req.headers.authorization?.split(" ")[1] as string | null;
return createInnerTRPCContext({ return createInnerTRPCContext({
session, session,

View File

@@ -9,8 +9,8 @@ import {
type OutputEvaluation, type OutputEvaluation,
type Dataset, type Dataset,
type DatasetEntry, type DatasetEntry,
type Organization, type Project,
type OrganizationUser, type ProjectUser,
type WorldChampEntrant, type WorldChampEntrant,
type LoggedCall, type LoggedCall,
type LoggedCallModelResponse, type LoggedCallModelResponse,
@@ -43,8 +43,8 @@ interface DB {
OutputEvaluation: OutputEvaluation; OutputEvaluation: OutputEvaluation;
Dataset: Dataset; Dataset: Dataset;
DatasetEntry: DatasetEntry; DatasetEntry: DatasetEntry;
Organization: Organization; Project: Project;
OrganizationUser: OrganizationUser; ProjectUser: ProjectUser;
WorldChampEntrant: WorldChampEntrant; WorldChampEntrant: WorldChampEntrant;
LoggedCall: LoggedCall; LoggedCall: LoggedCall;
LoggedCallModelResponse: LoggedCallModelResponse; LoggedCallModelResponse: LoggedCallModelResponse;

View File

@@ -4,21 +4,21 @@ import { generateApiKey } from "~/server/utils/generateApiKey";
console.log("backfilling api keys"); console.log("backfilling api keys");
const organizations = await prisma.organization.findMany({ const projects = await prisma.project.findMany({
include: { include: {
apiKeys: true, apiKeys: true,
}, },
}); });
console.log(`found ${organizations.length} organizations`); console.log(`found ${projects.length} projects`);
const apiKeysToCreate: Prisma.ApiKeyCreateManyInput[] = []; const apiKeysToCreate: Prisma.ApiKeyCreateManyInput[] = [];
for (const org of organizations) { for (const proj of projects) {
if (!org.apiKeys.length) { if (!proj.apiKeys.length) {
apiKeysToCreate.push({ apiKeysToCreate.push({
name: "Default API Key", name: "Default API Key",
organizationId: org.id, projectId: proj.id,
apiKey: generateApiKey(), apiKey: generateApiKey(),
}); });
} }

View File

@@ -1,63 +0,0 @@
import dayjs from "dayjs";
import { prisma } from "../db";
const projectId = "1234";
// Find all calls in the last 24 hours
const responses = await prisma.loggedCall.findMany({
where: {
organizationId: projectId,
startTime: {
gt: dayjs()
.subtract(24 * 3600)
.toDate(),
},
},
include: {
modelResponse: true,
},
orderBy: {
startTime: "desc",
},
});
// Find all calls in the last 24 hours with promptId 'hello-world'
const helloWorld = await prisma.loggedCall.findMany({
where: {
organizationId: projectId,
startTime: {
gt: dayjs()
.subtract(24 * 3600)
.toDate(),
},
tags: {
some: {
name: "promptId",
value: "hello-world",
},
},
},
include: {
modelResponse: true,
},
orderBy: {
startTime: "desc",
},
});
// Total spent on OpenAI in the last month
const totalSpent = await prisma.loggedCallModelResponse.aggregate({
_sum: {
totalCost: true,
},
where: {
originalLoggedCall: {
organizationId: projectId,
},
startTime: {
gt: dayjs()
.subtract(30 * 24 * 3600)
.toDate(),
},
},
});

View File

@@ -99,26 +99,27 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
} }
: null; : null;
const inputHash = hashObject(prompt as JsonValue); const cacheKey = hashObject(prompt as JsonValue);
let modelResponse = await prisma.modelResponse.create({ let modelResponse = await prisma.modelResponse.create({
data: { data: {
inputHash, cacheKey,
scenarioVariantCellId: cellId, scenarioVariantCellId: cellId,
requestedAt: new Date(), requestedAt: new Date(),
}, },
}); });
const response = await provider.getCompletion(prompt.modelInput, onStream); const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") { if (response.type === "success") {
const usage = provider.getUsage(prompt.modelInput, response.value);
modelResponse = await prisma.modelResponse.update({ modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id }, where: { id: modelResponse.id },
data: { data: {
output: response.value as Prisma.InputJsonObject, respPayload: response.value as Prisma.InputJsonObject,
statusCode: response.statusCode, statusCode: response.statusCode,
receivedAt: new Date(), receivedAt: new Date(),
promptTokens: response.promptTokens, inputTokens: usage?.inputTokens,
completionTokens: response.completionTokens, outputTokens: usage?.outputTokens,
cost: response.cost, cost: usage?.cost,
}, },
}); });

View File

@@ -1,6 +0,0 @@
export default function userError(message: string): { status: "error"; message: string } {
return {
status: "error",
message,
};
}

View File

@@ -51,7 +51,7 @@ export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelResponse.findMany({ const outputs = await prisma.modelResponse.findMany({
where: { where: {
outdated: false, outdated: false,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
scenarioVariantCell: { scenarioVariantCell: {

View File

@@ -57,7 +57,7 @@ export const generateNewCell = async (
return; return;
} }
const inputHash = hashObject(parsedConstructFn); const cacheKey = hashObject(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({ cell = await prisma.scenarioVariantCell.create({
data: { data: {
@@ -73,8 +73,8 @@ export const generateNewCell = async (
const matchingModelResponse = await prisma.modelResponse.findFirst({ const matchingModelResponse = await prisma.modelResponse.findFirst({
where: { where: {
inputHash, cacheKey,
output: { respPayload: {
not: Prisma.AnyNull, not: Prisma.AnyNull,
}, },
}, },
@@ -92,7 +92,7 @@ export const generateNewCell = async (
data: { data: {
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]), ...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
scenarioVariantCellId: cell.id, scenarioVariantCellId: cell.id,
output: matchingModelResponse.output as Prisma.InputJsonValue, respPayload: matchingModelResponse.respPayload as Prisma.InputJsonValue,
}, },
}); });

View File

@@ -24,9 +24,9 @@ function sortKeys(obj: JsonValue): JsonValue {
return sortedObj; return sortedObj;
} }
export function hashRequest(organizationId: string, reqPayload: JsonValue): string { export function hashRequest(projectId: string, reqPayload: JsonValue): string {
const obj = { const obj = {
organizationId, projectId,
reqPayload, reqPayload,
}; };
return hashObject(obj); return hashObject(obj);

View File

@@ -71,7 +71,7 @@ export const runOneEval = async (
provider: SupportedProvider, provider: SupportedProvider,
): Promise<{ result: number; details?: string }> => { ): Promise<{ result: number; details?: string }> => {
const modelProvider = modelProviders[provider]; const modelProvider = modelProviders[provider];
const message = modelProvider.normalizeOutput(modelResponse.output); const message = modelProvider.normalizeOutput(modelResponse.respPayload);
if (!message) return { result: 0 }; if (!message) return { result: 0 };

View File

@@ -1,15 +1,15 @@
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { generateApiKey } from "./generateApiKey"; import { generateApiKey } from "./generateApiKey";
export default async function userOrg(userId: string) { export default async function userProject(userId: string) {
return await prisma.organization.upsert({ return await prisma.project.upsert({
where: { where: {
personalOrgUserId: userId, personalProjectUserId: userId,
}, },
update: {}, update: {},
create: { create: {
personalOrgUserId: userId, personalProjectUserId: userId,
organizationUsers: { projectUsers: {
create: { create: {
userId: userId, userId: userId,
role: "ADMIN", role: "ADMIN",

13
app/src/state/persist.ts Normal file
View File

@@ -0,0 +1,13 @@
import { type PersistOptions } from "zustand/middleware/persist";
import { type State } from "./store";
export const stateToPersist = {
selectedProjectId: null as string | null,
};
export const persistOptions: PersistOptions<State, typeof stateToPersist> = {
name: "persisted-app-store",
partialize: (state) => ({
selectedProjectId: state.selectedProjectId,
}),
};

View File

@@ -8,9 +8,9 @@ export const editorBackground = "#fafafa";
export type SharedVariantEditorSlice = { export type SharedVariantEditorSlice = {
monaco: null | ReturnType<typeof loader.__getMonacoInstance>; monaco: null | ReturnType<typeof loader.__getMonacoInstance>;
loadMonaco: () => Promise<void>; loadMonaco: () => Promise<void>;
scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]; scenarioVars: RouterOutputs["scenarioVars"]["list"];
updateScenariosModel: () => void; updateScenariosModel: () => void;
setScenarios: (scenarios: RouterOutputs["scenarios"]["list"]["scenarios"]) => void; setScenarioVars: (scenarioVars: RouterOutputs["scenarioVars"]["list"]) => void;
}; };
export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> = (set, get) => ({
@@ -60,10 +60,10 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
}); });
get().sharedVariantEditor.updateScenariosModel(); get().sharedVariantEditor.updateScenariosModel();
}, },
scenarios: [], scenarioVars: [],
setScenarios: (scenarios) => { setScenarioVars: (scenarios) => {
set((state) => { set((state) => {
state.sharedVariantEditor.scenarios = scenarios; state.sharedVariantEditor.scenarioVars = scenarios;
}); });
get().sharedVariantEditor.updateScenariosModel(); get().sharedVariantEditor.updateScenariosModel();
@@ -74,16 +74,15 @@ export const createVariantEditorSlice: SliceCreator<SharedVariantEditorSlice> =
if (!monaco) return; if (!monaco) return;
const modelContents = ` const modelContents = `
const scenarios = ${JSON.stringify( declare var scenario: {
get().sharedVariantEditor.scenarios.map((s) => s.variableValues), ${get()
null, .sharedVariantEditor.scenarioVars.map((s) => `${s.label}: string;`)
2, .join("\n")}
)} as const; };
type Scenario = typeof scenarios[number];
declare var scenario: Scenario | { [key: string]: string };
`; `;
console.log(modelContents);
const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts")); const scenariosModel = monaco.editor.getModel(monaco.Uri.parse("file:///scenarios.ts"));
if (scenariosModel) { if (scenariosModel) {

View File

@@ -1,11 +1,13 @@
import { type StateCreator, create } from "zustand"; import { type StateCreator, create } from "zustand";
import { immer } from "zustand/middleware/immer"; import { immer } from "zustand/middleware/immer";
import { persist } from "zustand/middleware";
import { createSelectors } from "./createSelectors"; import { createSelectors } from "./createSelectors";
import { import {
type SharedVariantEditorSlice, type SharedVariantEditorSlice,
createVariantEditorSlice, createVariantEditorSlice,
} from "./sharedVariantEditor.slice"; } from "./sharedVariantEditor.slice";
import { type APIClient } from "~/utils/api"; import { type APIClient } from "~/utils/api";
import { persistOptions, type stateToPersist } from "./persist";
export type State = { export type State = {
drawerOpen: boolean; drawerOpen: boolean;
@@ -14,8 +16,8 @@ export type State = {
api: APIClient | null; api: APIClient | null;
setApi: (api: APIClient) => void; setApi: (api: APIClient) => void;
sharedVariantEditor: SharedVariantEditorSlice; sharedVariantEditor: SharedVariantEditorSlice;
selectedOrgId: string | null; selectedProjectId: string | null;
setSelectedOrgId: (orgId: string) => void; setselectedProjectId: (id: string) => void;
}; };
export type SliceCreator<T> = StateCreator<State, [["zustand/immer", never]], [], T>; export type SliceCreator<T> = StateCreator<State, [["zustand/immer", never]], [], T>;
@@ -23,30 +25,36 @@ export type SliceCreator<T> = StateCreator<State, [["zustand/immer", never]], []
export type SetFn = Parameters<SliceCreator<unknown>>[0]; export type SetFn = Parameters<SliceCreator<unknown>>[0];
export type GetFn = Parameters<SliceCreator<unknown>>[1]; export type GetFn = Parameters<SliceCreator<unknown>>[1];
const useBaseStore = create<State, [["zustand/immer", never]]>( const useBaseStore = create<
immer((set, get, ...rest) => ({ State,
api: null, [["zustand/persist", typeof stateToPersist], ["zustand/immer", never]]
setApi: (api) => >(
set((state) => { persist(
state.api = api; immer((set, get, ...rest) => ({
}), api: null,
setApi: (api) =>
set((state) => {
state.api = api;
}),
drawerOpen: false, drawerOpen: false,
openDrawer: () => openDrawer: () =>
set((state) => { set((state) => {
state.drawerOpen = true; state.drawerOpen = true;
}), }),
closeDrawer: () => closeDrawer: () =>
set((state) => { set((state) => {
state.drawerOpen = false; state.drawerOpen = false;
}), }),
sharedVariantEditor: createVariantEditorSlice(set, get, ...rest), sharedVariantEditor: createVariantEditorSlice(set, get, ...rest),
selectedOrgId: null, selectedProjectId: null,
setSelectedOrgId: (orgId: string) => setselectedProjectId: (id: string) =>
set((state) => { set((state) => {
state.selectedOrgId = orgId; state.selectedProjectId = id;
}), }),
})), })),
persistOptions,
),
); );
export const useAppStore = createSelectors(useBaseStore); export const useAppStore = createSelectors(useBaseStore);

View File

@@ -1,16 +1,16 @@
import { useEffect } from "react"; import { useEffect } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useScenarios } from "~/utils/hooks"; import { useScenarioVars } from "~/utils/hooks";
import { useAppStore } from "./store"; import { useAppStore } from "./store";
export function useSyncVariantEditor() { export function useSyncVariantEditor() {
const scenarios = useScenarios(); const scenarioVars = useScenarioVars();
useEffect(() => { useEffect(() => {
if (scenarios.data) { if (scenarioVars.data) {
useAppStore.getState().sharedVariantEditor.setScenarios(scenarios.data.scenarios); useAppStore.getState().sharedVariantEditor.setScenarioVars(scenarioVars.data);
} }
}, [scenarios.data]); }, [scenarioVars.data]);
} }
export function SyncAppStore() { export function SyncAppStore() {

View File

@@ -0,0 +1,5 @@
import { configDotenv } from "dotenv";
configDotenv({
path: ".env.test",
});

View File

@@ -0,0 +1,13 @@
import "./loadEnv";
import { sql } from "kysely";
import { beforeEach } from "vitest";
import { kysely } from "~/server/db";
// Reset all Prisma data
const resetDb = async () => {
await sql`truncate "Experiment" cascade;`.execute(kysely);
};
beforeEach(async () => {
await resetDb();
});

View File

@@ -1,4 +1,9 @@
import { extendTheme, defineStyleConfig, ChakraProvider } from "@chakra-ui/react"; import {
extendTheme,
defineStyleConfig,
ChakraProvider,
createStandaloneToast,
} from "@chakra-ui/react";
import "@fontsource/inconsolata"; import "@fontsource/inconsolata";
import { modalAnatomy } from "@chakra-ui/anatomy"; import { modalAnatomy } from "@chakra-ui/anatomy";
import { createMultiStyleConfigHelpers } from "@chakra-ui/styled-system"; import { createMultiStyleConfigHelpers } from "@chakra-ui/styled-system";
@@ -63,6 +68,15 @@ const theme = extendTheme({
}, },
}); });
const { ToastContainer, toast } = createStandaloneToast(theme);
export { toast };
export const ChakraThemeProvider = ({ children }: { children: JSX.Element }) => { export const ChakraThemeProvider = ({ children }: { children: JSX.Element }) => {
return <ChakraProvider theme={theme}>{children}</ChakraProvider>; return (
<ChakraProvider theme={theme}>
<ToastContainer />
{children}
</ChakraProvider>
);
}; };

View File

@@ -1,4 +1,4 @@
import { OrganizationUserRole } from "@prisma/client"; import { ProjectUserRole } from "@prisma/client";
import { TRPCError } from "@trpc/server"; import { TRPCError } from "@trpc/server";
import { type TRPCContext } from "~/server/api/trpc"; import { type TRPCContext } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
@@ -16,16 +16,16 @@ export const requireNothing = (ctx: TRPCContext) => {
ctx.markAccessControlRun(); ctx.markAccessControlRun();
}; };
export const requireIsOrgAdmin = async (organizationId: string, ctx: TRPCContext) => { export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
const isAdmin = await prisma.organizationUser.findFirst({ const isAdmin = await prisma.projectUser.findFirst({
where: { where: {
userId, userId,
organizationId, projectId,
role: "ADMIN", role: "ADMIN",
}, },
}); });
@@ -37,16 +37,16 @@ export const requireIsOrgAdmin = async (organizationId: string, ctx: TRPCContext
ctx.markAccessControlRun(); ctx.markAccessControlRun();
}; };
export const requireCanViewOrganization = async (organizationId: string, ctx: TRPCContext) => { export const requireCanViewProject = async (projectId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
const canView = await prisma.organizationUser.findFirst({ const canView = await prisma.projectUser.findFirst({
where: { where: {
userId, userId,
organizationId, projectId,
}, },
}); });
@@ -57,17 +57,17 @@ export const requireCanViewOrganization = async (organizationId: string, ctx: TR
ctx.markAccessControlRun(); ctx.markAccessControlRun();
}; };
export const requireCanModifyOrganization = async (organizationId: string, ctx: TRPCContext) => { export const requireCanModifyProject = async (projectId: string, ctx: TRPCContext) => {
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
const canModify = await prisma.organizationUser.findFirst({ const canModify = await prisma.projectUser.findFirst({
where: { where: {
userId, userId,
organizationId, projectId,
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, role: { in: [ProjectUserRole.ADMIN, ProjectUserRole.MEMBER] },
}, },
}); });
@@ -82,10 +82,10 @@ export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext)
const dataset = await prisma.dataset.findFirst({ const dataset = await prisma.dataset.findFirst({
where: { where: {
id: datasetId, id: datasetId,
organization: { project: {
organizationUsers: { projectUsers: {
some: { some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, role: { in: [ProjectUserRole.ADMIN, ProjectUserRole.MEMBER] },
userId: ctx.session?.user.id, userId: ctx.session?.user.id,
}, },
}, },
@@ -120,10 +120,10 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
prisma.experiment.findFirst({ prisma.experiment.findFirst({
where: { where: {
id: experimentId, id: experimentId,
organization: { project: {
organizationUsers: { projectUsers: {
some: { some: {
role: { in: [OrganizationUserRole.ADMIN, OrganizationUserRole.MEMBER] }, role: { in: [ProjectUserRole.ADMIN, ProjectUserRole.MEMBER] },
userId, userId,
}, },
}, },

View File

@@ -0,0 +1,20 @@
import { toast } from "~/theme/ChakraThemeProvider";
import { type error, type success } from "./standardResponses";
type SuccessType<T> = ReturnType<typeof success<T>>;
type ErrorType = ReturnType<typeof error>;
// Used client-side to report generic errors
export function maybeReportError<T>(response: SuccessType<T> | ErrorType): response is ErrorType {
if (response.status === "error") {
toast({
description: response.message,
status: "error",
duration: 5000,
isClosable: true,
});
return true;
}
return false;
}

View File

@@ -0,0 +1,11 @@
export function error(message: string): { status: "error"; message: string } {
return {
status: "error",
message,
};
}
export function success<T>(payload: T): { status: "success"; payload: T };
export function success(payload?: undefined): { status: "success"; payload: undefined };
export function success<T>(payload?: T) {
return { status: "success", payload };
}

View File

@@ -5,10 +5,10 @@ import { NumberParam, useQueryParam, withDefault } from "use-query-params";
import { useAppStore } from "~/state/store"; import { useAppStore } from "~/state/store";
export const useExperiments = () => { export const useExperiments = () => {
const selectedOrgId = useAppStore((state) => state.selectedOrgId); const selectedProjectId = useAppStore((state) => state.selectedProjectId);
return api.experiments.list.useQuery( return api.experiments.list.useQuery(
{ organizationId: selectedOrgId ?? "" }, { projectId: selectedProjectId ?? "" },
{ enabled: !!selectedOrgId }, { enabled: !!selectedProjectId },
); );
}; };
@@ -27,10 +27,10 @@ export const useExperimentAccess = () => {
}; };
export const useDatasets = () => { export const useDatasets = () => {
const selectedOrgId = useAppStore((state) => state.selectedOrgId); const selectedProjectId = useAppStore((state) => state.selectedProjectId);
return api.datasets.list.useQuery( return api.datasets.list.useQuery(
{ organizationId: selectedOrgId ?? "" }, { projectId: selectedProjectId ?? "" },
{ enabled: !!selectedOrgId }, { enabled: !!selectedProjectId },
); );
}; };
@@ -150,7 +150,19 @@ export const useScenario = (scenarioId: string) => {
export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? []; export const useVisibleScenarioIds = () => useScenarios().data?.scenarios.map((s) => s.id) ?? [];
export const useSelectedOrg = () => { export const useSelectedProject = () => {
const selectedOrgId = useAppStore((state) => state.selectedOrgId); const selectedProjectId = useAppStore((state) => state.selectedProjectId);
return api.organizations.get.useQuery({ id: selectedOrgId ?? "" }, { enabled: !!selectedOrgId }); return api.projects.get.useQuery(
{ id: selectedProjectId ?? "" },
{ enabled: !!selectedProjectId },
);
};
export const useScenarioVars = () => {
const experiment = useExperiment();
return api.scenarioVars.list.useQuery(
{ experimentId: experiment.data?.id ?? "" },
{ enabled: experiment.data?.id != null },
);
}; };

View File

@@ -4,6 +4,10 @@ import { configDefaults, defineConfig, type UserConfig } from "vitest/config";
const config = defineConfig({ const config = defineConfig({
test: { test: {
...configDefaults, // Extending Vitest's default options ...configDefaults, // Extending Vitest's default options
setupFiles: ["./src/tests/helpers/setup.ts"],
// Unfortunately using threads seems to cause issues with isolated-vm
threads: false,
}, },
plugins: [tsconfigPaths()], plugins: [tsconfigPaths()],
}) as UserConfig; }) as UserConfig;

Binary file not shown.

Binary file not shown.

View File

11
client-libs/python/codegen.sh Executable file
View File

@@ -0,0 +1,11 @@
#! /bin/bash
set -e
cd "$(dirname "$0")"
poetry run openapi-python-client generate --url http://localhost:3000/api/openapi.json
rm -rf openpipe/api_client
mv open-pipe-api-client/open_pipe_api_client openpipe/api_client
rm -rf open-pipe-api-client

View File

@@ -0,0 +1,10 @@
from .openai import OpenAIWrapper
from .shared import configured_client
openai = OpenAIWrapper()
def configure_openpipe(base_url=None, api_key=None):
if base_url is not None:
configured_client._base_url = base_url
if api_key is not None:
configured_client.token = api_key

View File

@@ -0,0 +1,7 @@
""" A client library for accessing OpenPipe API """
from .client import AuthenticatedClient, Client
__all__ = (
"AuthenticatedClient",
"Client",
)

View File

@@ -0,0 +1 @@
""" Contains methods for accessing the API """

View File

@@ -0,0 +1,155 @@
from http import HTTPStatus
from typing import Any, Dict, Optional, Union
import httpx
from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.external_api_check_cache_json_body import ExternalApiCheckCacheJsonBody
from ...models.external_api_check_cache_response_200 import ExternalApiCheckCacheResponse200
from ...types import Response
def _get_kwargs(
*,
json_body: ExternalApiCheckCacheJsonBody,
) -> Dict[str, Any]:
pass
json_json_body = json_body.to_dict()
return {
"method": "post",
"url": "/v1/check-cache",
"json": json_json_body,
}
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[ExternalApiCheckCacheResponse200]:
if response.status_code == HTTPStatus.OK:
response_200 = ExternalApiCheckCacheResponse200.from_dict(response.json())
return response_200
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None
def _build_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[ExternalApiCheckCacheResponse200]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)
def sync_detailed(
*,
client: AuthenticatedClient,
json_body: ExternalApiCheckCacheJsonBody,
) -> Response[ExternalApiCheckCacheResponse200]:
"""Check if a prompt is cached
Args:
json_body (ExternalApiCheckCacheJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[ExternalApiCheckCacheResponse200]
"""
kwargs = _get_kwargs(
json_body=json_body,
)
response = client.get_httpx_client().request(
**kwargs,
)
return _build_response(client=client, response=response)
def sync(
*,
client: AuthenticatedClient,
json_body: ExternalApiCheckCacheJsonBody,
) -> Optional[ExternalApiCheckCacheResponse200]:
"""Check if a prompt is cached
Args:
json_body (ExternalApiCheckCacheJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
ExternalApiCheckCacheResponse200
"""
return sync_detailed(
client=client,
json_body=json_body,
).parsed
async def asyncio_detailed(
*,
client: AuthenticatedClient,
json_body: ExternalApiCheckCacheJsonBody,
) -> Response[ExternalApiCheckCacheResponse200]:
"""Check if a prompt is cached
Args:
json_body (ExternalApiCheckCacheJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[ExternalApiCheckCacheResponse200]
"""
kwargs = _get_kwargs(
json_body=json_body,
)
response = await client.get_async_httpx_client().request(**kwargs)
return _build_response(client=client, response=response)
async def asyncio(
*,
client: AuthenticatedClient,
json_body: ExternalApiCheckCacheJsonBody,
) -> Optional[ExternalApiCheckCacheResponse200]:
"""Check if a prompt is cached
Args:
json_body (ExternalApiCheckCacheJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
ExternalApiCheckCacheResponse200
"""
return (
await asyncio_detailed(
client=client,
json_body=json_body,
)
).parsed

View File

@@ -0,0 +1,98 @@
from http import HTTPStatus
from typing import Any, Dict, Optional, Union
import httpx
from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.external_api_report_json_body import ExternalApiReportJsonBody
from ...types import Response
def _get_kwargs(
*,
json_body: ExternalApiReportJsonBody,
) -> Dict[str, Any]:
pass
json_json_body = json_body.to_dict()
return {
"method": "post",
"url": "/v1/report",
"json": json_json_body,
}
def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[Any]:
if response.status_code == HTTPStatus.OK:
return None
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None
def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[Any]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)
def sync_detailed(
*,
client: AuthenticatedClient,
json_body: ExternalApiReportJsonBody,
) -> Response[Any]:
"""Report an API call
Args:
json_body (ExternalApiReportJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Any]
"""
kwargs = _get_kwargs(
json_body=json_body,
)
response = client.get_httpx_client().request(
**kwargs,
)
return _build_response(client=client, response=response)
async def asyncio_detailed(
*,
client: AuthenticatedClient,
json_body: ExternalApiReportJsonBody,
) -> Response[Any]:
"""Report an API call
Args:
json_body (ExternalApiReportJsonBody):
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Any]
"""
kwargs = _get_kwargs(
json_body=json_body,
)
response = await client.get_async_httpx_client().request(**kwargs)
return _build_response(client=client, response=response)

View File

@@ -0,0 +1,268 @@
import ssl
from typing import Any, Dict, Optional, Union
import httpx
from attrs import define, evolve, field
@define
class Client:
"""A class for keeping track of data related to the API
The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
``base_url``: The base URL for the API, all requests are made to a relative path to this URL
``cookies``: A dictionary of cookies to be sent with every request
``headers``: A dictionary of headers to be sent with every request
``timeout``: The maximum amount of a time a request can take. API functions will raise
httpx.TimeoutException if this is exceeded.
``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
but can be set to False for testing purposes.
``follow_redirects``: Whether or not to follow redirects. Default value is False.
``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
Attributes:
raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
argument to the constructor.
"""
raise_on_unexpected_status: bool = field(default=False, kw_only=True)
_base_url: str
_cookies: Dict[str, str] = field(factory=dict, kw_only=True)
_headers: Dict[str, str] = field(factory=dict, kw_only=True)
_timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True)
_verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True)
_follow_redirects: bool = field(default=False, kw_only=True)
_httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True)
_client: Optional[httpx.Client] = field(default=None, init=False)
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
def with_headers(self, headers: Dict[str, str]) -> "Client":
"""Get a new client matching this one with additional headers"""
if self._client is not None:
self._client.headers.update(headers)
if self._async_client is not None:
self._async_client.headers.update(headers)
return evolve(self, headers={**self._headers, **headers})
def with_cookies(self, cookies: Dict[str, str]) -> "Client":
"""Get a new client matching this one with additional cookies"""
if self._client is not None:
self._client.cookies.update(cookies)
if self._async_client is not None:
self._async_client.cookies.update(cookies)
return evolve(self, cookies={**self._cookies, **cookies})
def with_timeout(self, timeout: httpx.Timeout) -> "Client":
"""Get a new client matching this one with a new timeout (in seconds)"""
if self._client is not None:
self._client.timeout = timeout
if self._async_client is not None:
self._async_client.timeout = timeout
return evolve(self, timeout=timeout)
def set_httpx_client(self, client: httpx.Client) -> "Client":
"""Manually the underlying httpx.Client
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
self._client = client
return self
def get_httpx_client(self) -> httpx.Client:
"""Get the underlying httpx.Client, constructing a new one if not previously set"""
if self._client is None:
self._client = httpx.Client(
base_url=self._base_url,
cookies=self._cookies,
headers=self._headers,
timeout=self._timeout,
verify=self._verify_ssl,
follow_redirects=self._follow_redirects,
**self._httpx_args,
)
return self._client
def __enter__(self) -> "Client":
"""Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
self.get_httpx_client().__enter__()
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Exit a context manager for internal httpx.Client (see httpx docs)"""
self.get_httpx_client().__exit__(*args, **kwargs)
def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client":
"""Manually the underlying httpx.AsyncClient
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
self._async_client = async_client
return self
def get_async_httpx_client(self) -> httpx.AsyncClient:
"""Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
if self._async_client is None:
self._async_client = httpx.AsyncClient(
base_url=self._base_url,
cookies=self._cookies,
headers=self._headers,
timeout=self._timeout,
verify=self._verify_ssl,
follow_redirects=self._follow_redirects,
**self._httpx_args,
)
return self._async_client
async def __aenter__(self) -> "Client":
"""Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
await self.get_async_httpx_client().__aenter__()
return self
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
"""Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
await self.get_async_httpx_client().__aexit__(*args, **kwargs)
@define
class AuthenticatedClient:
"""A Client which has been authenticated for use on secured endpoints
The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
``base_url``: The base URL for the API, all requests are made to a relative path to this URL
``cookies``: A dictionary of cookies to be sent with every request
``headers``: A dictionary of headers to be sent with every request
``timeout``: The maximum amount of a time a request can take. API functions will raise
httpx.TimeoutException if this is exceeded.
``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
but can be set to False for testing purposes.
``follow_redirects``: Whether or not to follow redirects. Default value is False.
``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
Attributes:
raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
argument to the constructor.
token: The token to use for authentication
prefix: The prefix to use for the Authorization header
auth_header_name: The name of the Authorization header
"""
raise_on_unexpected_status: bool = field(default=False, kw_only=True)
_base_url: str
_cookies: Dict[str, str] = field(factory=dict, kw_only=True)
_headers: Dict[str, str] = field(factory=dict, kw_only=True)
_timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True)
_verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True)
_follow_redirects: bool = field(default=False, kw_only=True)
_httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True)
_client: Optional[httpx.Client] = field(default=None, init=False)
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
token: str
prefix: str = "Bearer"
auth_header_name: str = "Authorization"
def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedClient":
"""Get a new client matching this one with additional headers"""
if self._client is not None:
self._client.headers.update(headers)
if self._async_client is not None:
self._async_client.headers.update(headers)
return evolve(self, headers={**self._headers, **headers})
def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedClient":
"""Get a new client matching this one with additional cookies"""
if self._client is not None:
self._client.cookies.update(cookies)
if self._async_client is not None:
self._async_client.cookies.update(cookies)
return evolve(self, cookies={**self._cookies, **cookies})
def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient":
"""Get a new client matching this one with a new timeout (in seconds)"""
if self._client is not None:
self._client.timeout = timeout
if self._async_client is not None:
self._async_client.timeout = timeout
return evolve(self, timeout=timeout)
def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient":
"""Manually the underlying httpx.Client
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
self._client = client
return self
def get_httpx_client(self) -> httpx.Client:
"""Get the underlying httpx.Client, constructing a new one if not previously set"""
if self._client is None:
self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token
self._client = httpx.Client(
base_url=self._base_url,
cookies=self._cookies,
headers=self._headers,
timeout=self._timeout,
verify=self._verify_ssl,
follow_redirects=self._follow_redirects,
**self._httpx_args,
)
return self._client
def __enter__(self) -> "AuthenticatedClient":
"""Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
self.get_httpx_client().__enter__()
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Exit a context manager for internal httpx.Client (see httpx docs)"""
self.get_httpx_client().__exit__(*args, **kwargs)
def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "AuthenticatedClient":
"""Manually the underlying httpx.AsyncClient
**NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
"""
self._async_client = async_client
return self
def get_async_httpx_client(self) -> httpx.AsyncClient:
"""Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
if self._async_client is None:
self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token
self._async_client = httpx.AsyncClient(
base_url=self._base_url,
cookies=self._cookies,
headers=self._headers,
timeout=self._timeout,
verify=self._verify_ssl,
follow_redirects=self._follow_redirects,
**self._httpx_args,
)
return self._async_client
async def __aenter__(self) -> "AuthenticatedClient":
"""Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
await self.get_async_httpx_client().__aenter__()
return self
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
"""Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
await self.get_async_httpx_client().__aexit__(*args, **kwargs)

View File

@@ -0,0 +1,14 @@
""" Contains shared errors types that can be raised from API functions """
class UnexpectedStatus(Exception):
"""Raised by api functions when the response status an undocumented status and Client.raise_on_unexpected_status is True"""
def __init__(self, status_code: int, content: bytes):
self.status_code = status_code
self.content = content
super().__init__(f"Unexpected status code: {status_code}")
__all__ = ["UnexpectedStatus"]

Some files were not shown because too many files have changed in this diff Show More