Compare commits

..

5 Commits

Author SHA1 Message Date
Kyle Corbitt
4e2ae7a441 Enqueue tasks more efficiently
Previously we were opening a new database connection for each task we added. Not a problem at small scale but kinda overwhelming for Postgres now that we have more usage.
2023-08-17 22:42:46 -07:00
Kyle Corbitt
94464c0617 Admin dashboard for jobs
Extremely simple jobs dashboard to sanity-check what we've got going on in the job queue.
2023-08-17 22:20:39 -07:00
arcticfly
980644f13c Support vicuna system message (#167)
* Support vicuna system message

* Change tags to USER and ASSISTANT
2023-08-17 21:02:27 -07:00
arcticfly
6a56250001 Add platypus 13b, vicuna 13b, and nous hermes 7b (#166)
* Add platypus

* Add vicuna 13b and nous hermes 7b
2023-08-17 20:01:10 -07:00
Kyle Corbitt
b1c7bbbd4a Merge pull request #165 from OpenPipe/better-output
Don't define CellWrapper inline
2023-08-17 19:07:32 -07:00
19 changed files with 597 additions and 96 deletions

View File

@@ -12,6 +12,7 @@ declare module "nextjs-routes" {
export type Route = export type Route =
| StaticRoute<"/account/signin"> | StaticRoute<"/account/signin">
| StaticRoute<"/admin/jobs">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }> | DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| StaticRoute<"/api/experiments/og-image"> | StaticRoute<"/api/experiments/og-image">
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }> | DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>

View File

@@ -18,6 +18,7 @@
"lint": "next lint", "lint": "next lint",
"start": "TZ=UTC next start", "start": "TZ=UTC next start",
"codegen:clients": "tsx src/server/scripts/client-codegen.ts", "codegen:clients": "tsx src/server/scripts/client-codegen.ts",
"codegen:db": "prisma generate && kysely-codegen --dialect postgres --out-file src/server/db.types.ts",
"seed": "tsx prisma/seed.ts", "seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'", "check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'",
"test": "pnpm vitest" "test": "pnpm vitest"
@@ -65,6 +66,7 @@
"json-stringify-pretty-compact": "^4.0.0", "json-stringify-pretty-compact": "^4.0.0",
"jsonschema": "^1.4.1", "jsonschema": "^1.4.1",
"kysely": "^0.26.1", "kysely": "^0.26.1",
"kysely-codegen": "^0.10.1",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"lucide-react": "^0.265.0", "lucide-react": "^0.265.0",
"marked": "^7.0.3", "marked": "^7.0.3",

View File

@@ -10,6 +10,14 @@ await prisma.project.deleteMany({
where: { id: defaultId }, where: { id: defaultId },
}); });
// Mark all users as admins
await prisma.user.updateMany({
where: {},
data: {
role: "ADMIN",
},
});
// If there's an existing project, just seed into it // If there's an existing project, just seed into it
const project = const project =
(await prisma.project.findFirst({})) ?? (await prisma.project.findFirst({})) ??
@@ -18,12 +26,16 @@ const project =
})); }));
if (env.OPENPIPE_API_KEY) { if (env.OPENPIPE_API_KEY) {
await prisma.apiKey.create({ await prisma.apiKey.upsert({
data: { where: {
apiKey: env.OPENPIPE_API_KEY,
},
create: {
projectId: project.id, projectId: project.id,
name: "Default API Key", name: "Default API Key",
apiKey: env.OPENPIPE_API_KEY, apiKey: env.OPENPIPE_API_KEY,
}, },
update: {},
}); });
} }

View File

@@ -12,7 +12,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",
@@ -29,7 +28,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",
@@ -126,7 +124,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",
@@ -143,7 +140,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-4", model: "gpt-4",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",
@@ -237,7 +233,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",

View File

@@ -3,10 +3,11 @@ import { type FrontendModelProvider } from "../types";
import { refinementActions } from "./refinementActions"; import { refinementActions } from "./refinementActions";
import { import {
templateOpenOrcaPrompt, templateOpenOrcaPrompt,
// templateAlpacaInstructPrompt, templateAlpacaInstructPrompt,
// templateSystemUserAssistantPrompt, // templateSystemUserAssistantPrompt,
templateInstructionInputResponsePrompt, templateInstructionInputResponsePrompt,
templateAiroborosPrompt, templateAiroborosPrompt,
templateVicunaPrompt,
} from "./templatePrompt"; } from "./templatePrompt";
const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = { const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = {
@@ -22,15 +23,16 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatO
learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B", learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
templatePrompt: templateOpenOrcaPrompt, templatePrompt: templateOpenOrcaPrompt,
}, },
// "Open-Orca/OpenOrca-Platypus2-13B": { "Open-Orca/OpenOrca-Platypus2-13B": {
// name: "OpenOrca-Platypus2-13B", name: "OpenOrca-Platypus2-13B",
// contextWindow: 4096, contextWindow: 4096,
// pricePerSecond: 0.0003, pricePerSecond: 0.0003,
// speed: "medium", speed: "medium",
// provider: "openpipe/Chat", provider: "openpipe/Chat",
// learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrca-Platypus2-13B", learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrca-Platypus2-13B",
// templatePrompt: templateAlpacaInstructPrompt, templatePrompt: templateAlpacaInstructPrompt,
// }, defaultStopTokens: ["</s>"],
},
// "stabilityai/StableBeluga-13B": { // "stabilityai/StableBeluga-13B": {
// name: "StableBeluga-13B", // name: "StableBeluga-13B",
// contextWindow: 4096, // contextWindow: 4096,
@@ -58,6 +60,24 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatO
learnMoreUrl: "https://huggingface.co/jondurbin/airoboros-l2-13b-gpt4-2.0", learnMoreUrl: "https://huggingface.co/jondurbin/airoboros-l2-13b-gpt4-2.0",
templatePrompt: templateAiroborosPrompt, templatePrompt: templateAiroborosPrompt,
}, },
"lmsys/vicuna-13b-v1.5": {
name: "vicuna-13b-v1.5",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/lmsys/vicuna-13b-v1.5",
templatePrompt: templateVicunaPrompt,
},
"NousResearch/Nous-Hermes-llama-2-7b": {
name: "Nous-Hermes-llama-2-7b",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b",
templatePrompt: templateInstructionInputResponsePrompt,
},
}, },
refinementActions, refinementActions,

View File

@@ -8,10 +8,12 @@ import frontendModelProvider from "./frontend";
const modelEndpoints: Record<OpenpipeChatInput["model"], string> = { const modelEndpoints: Record<OpenpipeChatInput["model"], string> = {
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B": "https://5ef82gjxk8kdys-8000.proxy.runpod.net/v1", "Open-Orca/OpenOrcaxOpenChat-Preview2-13B": "https://5ef82gjxk8kdys-8000.proxy.runpod.net/v1",
// "Open-Orca/OpenOrca-Platypus2-13B": "https://lt5qlel6qcji8t-8000.proxy.runpod.net/v1", "Open-Orca/OpenOrca-Platypus2-13B": "https://lt5qlel6qcji8t-8000.proxy.runpod.net/v1",
// "stabilityai/StableBeluga-13B": "https://vcorl8mxni2ou1-8000.proxy.runpod.net/v1", // "stabilityai/StableBeluga-13B": "https://vcorl8mxni2ou1-8000.proxy.runpod.net/v1",
"NousResearch/Nous-Hermes-Llama2-13b": "https://ncv8pw3u0vb8j2-8000.proxy.runpod.net/v1", "NousResearch/Nous-Hermes-Llama2-13b": "https://ncv8pw3u0vb8j2-8000.proxy.runpod.net/v1",
"jondurbin/airoboros-l2-13b-gpt4-2.0": "https://9nrbx7oph4btou-8000.proxy.runpod.net/v1", "jondurbin/airoboros-l2-13b-gpt4-2.0": "https://9nrbx7oph4btou-8000.proxy.runpod.net/v1",
"lmsys/vicuna-13b-v1.5": "https://h88hkt3ux73rb7-8000.proxy.runpod.net/v1",
"NousResearch/Nous-Hermes-llama-2-7b": "https://ua1bpc6kv3dgge-8000.proxy.runpod.net/v1",
}; };
export async function getCompletion( export async function getCompletion(
@@ -36,10 +38,20 @@ export async function getCompletion(
const start = Date.now(); const start = Date.now();
let finalCompletion: OpenpipeChatOutput = ""; let finalCompletion: OpenpipeChatOutput = "";
const completionParams = {
model,
prompt: templatedPrompt,
...rest,
};
if (!completionParams.stop && frontendModelProvider.models[model].defaultStopTokens) {
completionParams.stop = frontendModelProvider.models[model].defaultStopTokens;
}
try { try {
if (onStream) { if (onStream) {
const resp = await openai.completions.create( const resp = await openai.completions.create(
{ model, prompt: templatedPrompt, ...rest, stream: true }, { ...completionParams, stream: true },
{ {
maxRetries: 0, maxRetries: 0,
}, },
@@ -58,7 +70,7 @@ export async function getCompletion(
} }
} else { } else {
const resp = await openai.completions.create( const resp = await openai.completions.create(
{ model, prompt: templatedPrompt, ...rest, stream: false }, { ...completionParams, stream: false },
{ {
maxRetries: 0, maxRetries: 0,
}, },

View File

@@ -6,10 +6,12 @@ import frontendModelProvider from "./frontend";
const supportedModels = [ const supportedModels = [
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B", "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
// "Open-Orca/OpenOrca-Platypus2-13B", "Open-Orca/OpenOrca-Platypus2-13B",
// "stabilityai/StableBeluga-13B", // "stabilityai/StableBeluga-13B",
"NousResearch/Nous-Hermes-Llama2-13b", "NousResearch/Nous-Hermes-Llama2-13b",
"jondurbin/airoboros-l2-13b-gpt4-2.0", "jondurbin/airoboros-l2-13b-gpt4-2.0",
"lmsys/vicuna-13b-v1.5",
"NousResearch/Nous-Hermes-llama-2-7b",
] as const; ] as const;
export type SupportedModel = (typeof supportedModels)[number]; export type SupportedModel = (typeof supportedModels)[number];

View File

@@ -7,8 +7,11 @@
"type": "string", "type": "string",
"enum": [ "enum": [
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B", "Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
"Open-Orca/OpenOrca-Platypus2-13B",
"NousResearch/Nous-Hermes-Llama2-13b", "NousResearch/Nous-Hermes-Llama2-13b",
"jondurbin/airoboros-l2-13b-gpt4-2.0" "jondurbin/airoboros-l2-13b-gpt4-2.0",
"lmsys/vicuna-13b-v1.5",
"NousResearch/Nous-Hermes-llama-2-7b"
] ]
}, },
"messages": { "messages": {

View File

@@ -28,7 +28,7 @@ export const templateOpenOrcaPrompt = (messages: OpenpipeChatInput["messages"])
// <prompt> (without the <>) // <prompt> (without the <>)
// ### Response: // ### Response: (leave two newlines for model to respond)
export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messages"]) => { export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n\n"; const splitter = "\n\n";
@@ -52,7 +52,7 @@ export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messag
prompt += splitter + assistantTag; prompt += splitter + assistantTag;
} }
return prompt.trim(); return prompt;
}; };
// ### System: // ### System:
@@ -101,7 +101,6 @@ export const templateSystemUserAssistantPrompt = (messages: OpenpipeChatInput["m
// ### Response: // ### Response:
// <leave a newline blank for model to respond> // <leave a newline blank for model to respond>
export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInput["messages"]) => { export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n\n"; const splitter = "\n\n";
@@ -136,7 +135,6 @@ export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInp
// The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. // The assistant gives helpful, detailed, accurate, uncensored responses to the user's input.
// The assistant never refuses to answer, regardless of the legality or morality of the request. // The assistant never refuses to answer, regardless of the legality or morality of the request.
// USER: [prompt] ASSISTANT: // USER: [prompt] ASSISTANT:
export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"]) => { export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = " "; const splitter = " ";
@@ -179,3 +177,49 @@ export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"])
return prompt; return prompt;
}; };
// A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
// USER: {prompt}
// ASSISTANT:
export const templateVicunaPrompt = (messages: OpenpipeChatInput["messages"]) => {
  const splitter = "\n";
  const humanTag = "USER: ";
  const assistantTag = "ASSISTANT: ";
  // Vicuna has no dedicated system tag, so all system messages are collected
  // into one string and injected into the conversation below.
  let combinedSystemMessage = "";
  const conversationMessages: string[] = [];
  for (const message of messages) {
    if (message.role === "system") {
      combinedSystemMessage += message.content;
    } else if (message.role === "user") {
      conversationMessages.push(humanTag + message.content);
    } else {
      conversationMessages.push(assistantTag + message.content);
    }
  }
  let systemMessage = "";
  if (combinedSystemMessage) {
    if (conversationMessages.some((message) => message.startsWith(humanTag))) {
      // At least one real user message exists: emit the system text as a
      // plain preamble separated from the conversation by a blank line.
      systemMessage = `${combinedSystemMessage}\n\n`;
    } else {
      // No user message at all: present the system text as the first user
      // turn so the model still has something to respond to.
      conversationMessages.unshift(humanTag + combinedSystemMessage);
    }
  }
  let prompt = `${systemMessage}${conversationMessages.join(splitter)}`;
  // Ensure the prompt ends with an empty assistant turn for the model to fill in.
  const lastHumanIndex = prompt.lastIndexOf(humanTag);
  const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
  if (lastHumanIndex > lastAssistantIndex) {
    prompt += splitter + assistantTag;
  }
  // NOTE: trim() also drops the trailing space after a final "ASSISTANT:" tag.
  return prompt.trim();
};

View File

@@ -25,6 +25,7 @@ export type Model = {
learnMoreUrl?: string; learnMoreUrl?: string;
apiDocsUrl?: string; apiDocsUrl?: string;
templatePrompt?: (initialPrompt: OpenpipeChatInput["messages"]) => string; templatePrompt?: (initialPrompt: OpenpipeChatInput["messages"]) => string;
defaultStopTokens?: string[];
}; };
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string }; export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };

View File

@@ -0,0 +1,54 @@
import { Card, Table, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
import dayjs from "dayjs";
import { isDate, isObject, isString } from "lodash-es";
import AppShell from "~/components/nav/AppShell";
import { type RouterOutputs, api } from "~/utils/api";
// Columns of a graphile_worker.jobs row to render, in table order. Typed
// against the adminJobs.list output so a renamed column fails to compile.
const fieldsToShow: (keyof RouterOutputs["adminJobs"]["list"][0])[] = [
  "id",
  "queue_name",
  "payload",
  "priority",
  "attempts",
  "last_error",
  "created_at",
  "key",
  "locked_at",
  "run_at",
];
// Admin page listing recent background jobs (data from the adminJobs.list
// query) as a simple read-only table for sanity-checking the queue.
export default function Jobs() {
  const jobs = api.adminJobs.list.useQuery({});
  return (
    <AppShell title="Admin Jobs">
      <Card m={4} overflowX="auto">
        <Table>
          <Thead>
            <Tr>
              {fieldsToShow.map((field) => (
                <Th key={field}>{field}</Th>
              ))}
            </Tr>
          </Thead>
          <Tbody>
            {jobs.data?.map((job) => (
              <Tr key={job.id}>
                {fieldsToShow.map((field) => {
                  // Normalize non-primitive cell values so they render as text:
                  // dates are formatted, other objects are JSON-stringified.
                  let value = job[field];
                  if (isDate(value)) {
                    value = dayjs(value).format("YYYY-MM-DD HH:mm:ss");
                  } else if (isObject(value) && !isString(value)) {
                    value = JSON.stringify(value);
                  }
                  return <Td key={field}>{value}</Td>;
                })}
              </Tr>
            ))}
          </Tbody>
        </Table>
      </Card>
    </AppShell>
  );
}

View File

@@ -12,6 +12,7 @@ import { projectsRouter } from "./routers/projects.router";
import { dashboardRouter } from "./routers/dashboard.router"; import { dashboardRouter } from "./routers/dashboard.router";
import { loggedCallsRouter } from "./routers/loggedCalls.router"; import { loggedCallsRouter } from "./routers/loggedCalls.router";
import { usersRouter } from "./routers/users.router"; import { usersRouter } from "./routers/users.router";
import { adminJobsRouter } from "./routers/adminJobs.router";
/** /**
* This is the primary router for your server. * This is the primary router for your server.
@@ -32,6 +33,7 @@ export const appRouter = createTRPCRouter({
dashboard: dashboardRouter, dashboard: dashboardRouter,
loggedCalls: loggedCallsRouter, loggedCalls: loggedCallsRouter,
users: usersRouter, users: usersRouter,
adminJobs: adminJobsRouter,
}); });
// export type definition of API // export type definition of API

View File

@@ -0,0 +1,18 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { kysely } from "~/server/db";
import { requireIsAdmin } from "~/utils/accessControl";
// TRPC router backing the /admin/jobs dashboard; every procedure is gated on
// requireIsAdmin.
export const adminJobsRouter = createTRPCRouter({
  // Returns up to 100 graphile_worker jobs, newest first (by created_at).
  list: protectedProcedure.input(z.object({})).query(async ({ ctx }) => {
    await requireIsAdmin(ctx);
    return await kysely
      .selectFrom("graphile_worker.jobs")
      .limit(100)
      .selectAll()
      .orderBy("created_at", "desc")
      .execute();
  }),
});

View File

@@ -335,7 +335,6 @@ export const experimentsRouter = createTRPCRouter({
definePrompt("openai/ChatCompletion", { definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",

View File

@@ -1,27 +1,6 @@
import { import { type DB } from "./db.types";
type Experiment,
type PromptVariant, import { PrismaClient } from "@prisma/client";
type TestScenario,
type TemplateVariable,
type ScenarioVariantCell,
type ModelResponse,
type Evaluation,
type OutputEvaluation,
type Dataset,
type DatasetEntry,
type Project,
type ProjectUser,
type WorldChampEntrant,
type LoggedCall,
type LoggedCallModelResponse,
type LoggedCallTag,
type ApiKey,
type Account,
type Session,
type User,
type VerificationToken,
PrismaClient,
} from "@prisma/client";
import { Kysely, PostgresDialect } from "kysely"; import { Kysely, PostgresDialect } from "kysely";
// TODO: Revert to normal import when our tsconfig.json is fixed // TODO: Revert to normal import when our tsconfig.json is fixed
// import { Pool } from "pg"; // import { Pool } from "pg";
@@ -32,30 +11,6 @@ const Pool = (UntypedPool.default ? UntypedPool.default : UntypedPool) as typeof
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
interface DB {
Experiment: Experiment;
PromptVariant: PromptVariant;
TestScenario: TestScenario;
TemplateVariable: TemplateVariable;
ScenarioVariantCell: ScenarioVariantCell;
ModelResponse: ModelResponse;
Evaluation: Evaluation;
OutputEvaluation: OutputEvaluation;
Dataset: Dataset;
DatasetEntry: DatasetEntry;
Project: Project;
ProjectUser: ProjectUser;
WorldChampEntrant: WorldChampEntrant;
LoggedCall: LoggedCall;
LoggedCallModelResponse: LoggedCallModelResponse;
LoggedCallTag: LoggedCallTag;
ApiKey: ApiKey;
Account: Account;
Session: Session;
User: User;
VerificationToken: VerificationToken;
}
const globalForPrisma = globalThis as unknown as { const globalForPrisma = globalThis as unknown as {
prisma: PrismaClient | undefined; prisma: PrismaClient | undefined;
}; };

336
app/src/server/db.types.ts Normal file
View File

@@ -0,0 +1,336 @@
import type { ColumnType } from "kysely";
// NOTE(review): this file appears to be generated by kysely-codegen (see the
// package.json "codegen:db" script, which writes src/server/db.types.ts) —
// hand edits here will be overwritten on the next regeneration.
// Columns with a database-side default: the insert type also accepts undefined.
export type Generated<T> = T extends ColumnType<infer S, infer I, infer U>
  ? ColumnType<S, I | undefined, U>
  : ColumnType<T, T | undefined, T>;
// Postgres int8 is selected as a string (avoids JS number precision loss).
export type Int8 = ColumnType<string, string | number | bigint, string | number | bigint>;
export type Json = ColumnType<JsonValue, string, string>;
export type JsonArray = JsonValue[];
export type JsonObject = {
  [K in string]?: JsonValue;
};
export type JsonPrimitive = boolean | null | number | string;
export type JsonValue = JsonArray | JsonObject | JsonPrimitive;
export type Numeric = ColumnType<string, string | number, string | number>;
export type Timestamp = ColumnType<Date, Date | string, Date | string>;
export interface _PrismaMigrations {
id: string;
checksum: string;
finished_at: Timestamp | null;
migration_name: string;
logs: string | null;
rolled_back_at: Timestamp | null;
started_at: Generated<Timestamp>;
applied_steps_count: Generated<number>;
}
export interface Account {
id: string;
userId: string;
type: string;
provider: string;
providerAccountId: string;
refresh_token: string | null;
refresh_token_expires_in: number | null;
access_token: string | null;
expires_at: number | null;
token_type: string | null;
scope: string | null;
id_token: string | null;
session_state: string | null;
}
export interface ApiKey {
id: string;
name: string;
apiKey: string;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Dataset {
id: string;
name: string;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface DatasetEntry {
id: string;
input: string;
output: string | null;
datasetId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Evaluation {
id: string;
label: string;
value: string;
evalType: "CONTAINS" | "DOES_NOT_CONTAIN" | "GPT4_EVAL";
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Experiment {
id: string;
label: string;
sortIndex: Generated<number>;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
projectId: string;
}
export interface GraphileWorkerJobQueues {
queue_name: string;
job_count: number;
locked_at: Timestamp | null;
locked_by: string | null;
}
export interface GraphileWorkerJobs {
id: Generated<Int8>;
queue_name: string | null;
task_identifier: string;
payload: Generated<Json>;
priority: Generated<number>;
run_at: Generated<Timestamp>;
attempts: Generated<number>;
max_attempts: Generated<number>;
last_error: string | null;
created_at: Generated<Timestamp>;
updated_at: Generated<Timestamp>;
key: string | null;
locked_at: Timestamp | null;
locked_by: string | null;
revision: Generated<number>;
flags: Json | null;
}
export interface GraphileWorkerKnownCrontabs {
identifier: string;
known_since: Timestamp;
last_execution: Timestamp | null;
}
export interface GraphileWorkerMigrations {
id: number;
ts: Generated<Timestamp>;
}
export interface LoggedCall {
id: string;
requestedAt: Timestamp;
cacheHit: boolean;
modelResponseId: string | null;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
model: string | null;
}
export interface LoggedCallModelResponse {
id: string;
reqPayload: Json;
statusCode: number | null;
respPayload: Json | null;
errorMessage: string | null;
requestedAt: Timestamp;
receivedAt: Timestamp;
cacheKey: string | null;
durationMs: number | null;
inputTokens: number | null;
outputTokens: number | null;
finishReason: string | null;
completionId: string | null;
cost: Numeric | null;
originalLoggedCallId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface LoggedCallTag {
id: string;
name: string;
value: string | null;
loggedCallId: string;
projectId: string;
}
export interface ModelResponse {
id: string;
cacheKey: string;
respPayload: Json | null;
inputTokens: number | null;
outputTokens: number | null;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
scenarioVariantCellId: string;
cost: number | null;
requestedAt: Timestamp | null;
receivedAt: Timestamp | null;
statusCode: number | null;
errorMessage: string | null;
retryTime: Timestamp | null;
outdated: Generated<boolean>;
}
export interface OutputEvaluation {
id: string;
result: number;
details: string | null;
modelResponseId: string;
evaluationId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Project {
id: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
personalProjectUserId: string | null;
name: Generated<string>;
}
export interface ProjectUser {
id: string;
role: "ADMIN" | "MEMBER" | "VIEWER";
projectId: string;
userId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface PromptVariant {
id: string;
label: string;
uiId: string;
visible: Generated<boolean>;
sortIndex: Generated<number>;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
promptConstructor: string;
model: string;
promptConstructorVersion: number;
modelProvider: string;
}
export interface ScenarioVariantCell {
id: string;
errorMessage: string | null;
promptVariantId: string;
testScenarioId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
retrievalStatus: Generated<"COMPLETE" | "ERROR" | "IN_PROGRESS" | "PENDING">;
prompt: Json | null;
jobQueuedAt: Timestamp | null;
jobStartedAt: Timestamp | null;
}
export interface Session {
id: string;
sessionToken: string;
userId: string;
expires: Timestamp;
}
export interface TemplateVariable {
id: string;
label: string;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface TestScenario {
id: string;
variableValues: Json;
uiId: string;
visible: Generated<boolean>;
sortIndex: Generated<number>;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface User {
id: string;
name: string | null;
email: string | null;
emailVerified: Timestamp | null;
image: string | null;
createdAt: Generated<Timestamp>;
updatedAt: Generated<Timestamp>;
role: Generated<"ADMIN" | "USER">;
}
export interface UserInvitation {
id: string;
projectId: string;
email: string;
role: "ADMIN" | "MEMBER" | "VIEWER";
invitationToken: string;
senderId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface VerificationToken {
identifier: string;
token: string;
expires: Timestamp;
}
export interface WorldChampEntrant {
id: string;
userId: string;
approved: Generated<boolean>;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface DB {
_prisma_migrations: _PrismaMigrations;
Account: Account;
ApiKey: ApiKey;
Dataset: Dataset;
DatasetEntry: DatasetEntry;
Evaluation: Evaluation;
Experiment: Experiment;
"graphile_worker.job_queues": GraphileWorkerJobQueues;
"graphile_worker.jobs": GraphileWorkerJobs;
"graphile_worker.known_crontabs": GraphileWorkerKnownCrontabs;
"graphile_worker.migrations": GraphileWorkerMigrations;
LoggedCall: LoggedCall;
LoggedCallModelResponse: LoggedCallModelResponse;
LoggedCallTag: LoggedCallTag;
ModelResponse: ModelResponse;
OutputEvaluation: OutputEvaluation;
Project: Project;
ProjectUser: ProjectUser;
PromptVariant: PromptVariant;
ScenarioVariantCell: ScenarioVariantCell;
Session: Session;
TemplateVariable: TemplateVariable;
TestScenario: TestScenario;
User: User;
UserInvitation: UserInvitation;
VerificationToken: VerificationToken;
WorldChampEntrant: WorldChampEntrant;
}

View File

@@ -1,15 +1,24 @@
// Import necessary dependencies import { type Helpers, type Task, makeWorkerUtils } from "graphile-worker";
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
// Define the defineTask function let workerUtilsPromise: ReturnType<typeof makeWorkerUtils> | null = null;
function workerUtils() {
if (!workerUtilsPromise) {
workerUtilsPromise = makeWorkerUtils({
connectionString: env.DATABASE_URL,
});
}
return workerUtilsPromise;
}
function defineTask<TPayload>( function defineTask<TPayload>(
taskIdentifier: string, taskIdentifier: string,
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>, taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) { ) {
const enqueue = async (payload: TPayload, runAt?: Date) => { const enqueue = async (payload: TPayload, runAt?: Date) => {
console.log("Enqueuing task", taskIdentifier, payload); console.log("Enqueuing task", taskIdentifier, payload);
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt }); await (await workerUtils()).addJob(taskIdentifier, payload, { runAt });
}; };
const handler = (payload: TPayload, helpers: Helpers) => { const handler = (payload: TPayload, helpers: Helpers) => {

View File

@@ -17,6 +17,8 @@ export const requireNothing = (ctx: TRPCContext) => {
}; };
export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext) => { export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -33,11 +35,11 @@ export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext)
if (!isAdmin) { if (!isAdmin) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
ctx.markAccessControlRun();
}; };
export const requireCanViewProject = async (projectId: string, ctx: TRPCContext) => { export const requireCanViewProject = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -53,11 +55,11 @@ export const requireCanViewProject = async (projectId: string, ctx: TRPCContext)
if (!canView) { if (!canView) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
ctx.markAccessControlRun();
}; };
export const requireCanModifyProject = async (projectId: string, ctx: TRPCContext) => { export const requireCanModifyProject = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -74,11 +76,11 @@ export const requireCanModifyProject = async (projectId: string, ctx: TRPCContex
if (!canModify) { if (!canModify) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
ctx.markAccessControlRun();
}; };
export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => { export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const dataset = await prisma.dataset.findFirst({ const dataset = await prisma.dataset.findFirst({
where: { where: {
id: datasetId, id: datasetId,
@@ -96,8 +98,6 @@ export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext)
if (!dataset) { if (!dataset) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
ctx.markAccessControlRun();
}; };
export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => { export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => {
@@ -105,13 +105,10 @@ export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContex
await requireCanViewDataset(datasetId, ctx); await requireCanViewDataset(datasetId, ctx);
}; };
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => { export const requireCanViewExperiment = (experimentId: string, ctx: TRPCContext): Promise<void> => {
await prisma.experiment.findFirst({
where: { id: experimentId },
});
// Right now all experiments are publicly viewable, so this is a no-op. // Right now all experiments are publicly viewable, so this is a no-op.
ctx.markAccessControlRun(); ctx.markAccessControlRun();
return Promise.resolve();
}; };
export const canModifyExperiment = async (experimentId: string, userId: string) => { export const canModifyExperiment = async (experimentId: string, userId: string) => {
@@ -136,6 +133,8 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
}; };
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => { export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id; const userId = ctx.session?.user.id;
if (!userId) { if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -144,6 +143,17 @@ export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPC
if (!(await canModifyExperiment(experimentId, userId))) { if (!(await canModifyExperiment(experimentId, userId))) {
throw new TRPCError({ code: "UNAUTHORIZED" }); throw new TRPCError({ code: "UNAUTHORIZED" });
} }
};
ctx.markAccessControlRun();
// Throws UNAUTHORIZED unless the session user exists and passes the isAdmin
// check (site-wide admin, not a per-project role).
export const requireIsAdmin = async (ctx: TRPCContext) => {
  ctx.markAccessControlRun();
  const userId = ctx.session?.user.id;
  if (!userId) {
    throw new TRPCError({ code: "UNAUTHORIZED" });
  }
  if (!(await isAdmin(userId))) {
    throw new TRPCError({ code: "UNAUTHORIZED" });
  }
}; };

28
pnpm-lock.yaml generated
View File

@@ -134,6 +134,9 @@ importers:
kysely: kysely:
specifier: ^0.26.1 specifier: ^0.26.1
version: 0.26.1 version: 0.26.1
kysely-codegen:
specifier: ^0.10.1
version: 0.10.1(kysely@0.26.1)(pg@8.11.2)
lodash-es: lodash-es:
specifier: ^4.17.21 specifier: ^4.17.21
version: 4.17.21 version: 4.17.21
@@ -6391,6 +6394,30 @@ packages:
object.values: 1.1.6 object.values: 1.1.6
dev: true dev: true
/kysely-codegen@0.10.1(kysely@0.26.1)(pg@8.11.2):
resolution: {integrity: sha512-8Bslh952gN5gtucRv4jTZDFD18RBioS6M50zHfe5kwb5iSyEAunU4ZYMdHzkHraa4zxjg5/183XlOryBCXLRIw==}
hasBin: true
peerDependencies:
better-sqlite3: '>=7.6.2'
kysely: '>=0.19.12'
mysql2: ^2.3.3 || ^3.0.0
pg: ^8.8.0
peerDependenciesMeta:
better-sqlite3:
optional: true
mysql2:
optional: true
pg:
optional: true
dependencies:
chalk: 4.1.2
dotenv: 16.3.1
kysely: 0.26.1
micromatch: 4.0.5
minimist: 1.2.8
pg: 8.11.2
dev: false
/kysely@0.26.1: /kysely@0.26.1:
resolution: {integrity: sha512-FVRomkdZofBu3O8SiwAOXrwbhPZZr8mBN5ZeUWyprH29jzvy6Inzqbd0IMmGxpd4rcOCL9HyyBNWBa8FBqDAdg==} resolution: {integrity: sha512-FVRomkdZofBu3O8SiwAOXrwbhPZZr8mBN5ZeUWyprH29jzvy6Inzqbd0IMmGxpd4rcOCL9HyyBNWBa8FBqDAdg==}
engines: {node: '>=14.0.0'} engines: {node: '>=14.0.0'}
@@ -6611,7 +6638,6 @@ packages:
dependencies: dependencies:
braces: 3.0.2 braces: 3.0.2
picomatch: 2.3.1 picomatch: 2.3.1
dev: true
/mime-db@1.52.0: /mime-db@1.52.0:
resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==} resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==}