move app to app/ subdir
This commit is contained in:
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { isAxiosError } from "./utils";
|
||||
import { type APIResponse } from "openai/core";
|
||||
import { sleep } from "~/server/utils/sleep";
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithBackoff = async (
|
||||
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||
) => {
|
||||
let completion;
|
||||
let tries = 0;
|
||||
while (tries < MAX_AUTO_RETRIES) {
|
||||
try {
|
||||
completion = await getCompletion();
|
||||
break;
|
||||
} catch (e) {
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
await sleep(calculateDelay(tries));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
tries++;
|
||||
}
|
||||
return completion;
|
||||
};
|
||||
// TODO: Add seeds to ensure batches don't contain duplicate data
|
||||
const MAX_BATCH_SIZE = 5;
|
||||
|
||||
export const autogenerateDatasetEntries = async (
|
||||
numToGenerate: number,
|
||||
inputDescription: string,
|
||||
outputDescription: string,
|
||||
): Promise<{ input: string; output: string }[]> => {
|
||||
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE
|
||||
? numToGenerate % MAX_BATCH_SIZE
|
||||
: MAX_BATCH_SIZE,
|
||||
);
|
||||
|
||||
const getCompletion = (batchSize: number) =>
|
||||
openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "add_list_of_data",
|
||||
description: "Add a list of data to the database",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
rows: {
|
||||
type: "array",
|
||||
description: "The rows of data that match the description",
|
||||
items: {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "The input for this row",
|
||||
},
|
||||
output: {
|
||||
type: "string",
|
||||
description: "The output for this row",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_list_of_data" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||
);
|
||||
|
||||
const completions = await Promise.all(completionCallbacks);
|
||||
|
||||
const rows = completions.flatMap((completion) => {
|
||||
const parsed = JSON.parse(
|
||||
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||
) as { rows: { input: string; output: string }[] };
|
||||
return parsed.rows;
|
||||
});
|
||||
|
||||
return rows;
|
||||
};
|
||||
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { prisma } from "../../db";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { pick } from "lodash-es";
|
||||
import { isAxiosError } from "./utils";
|
||||
|
||||
export const autogenerateScenarioValues = async (
|
||||
experimentId: string,
|
||||
): Promise<Record<string, string>> => {
|
||||
const [experiment, variables, existingScenarios, prompt] = await Promise.all([
|
||||
prisma.experiment.findUnique({
|
||||
where: {
|
||||
id: experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
take: 10,
|
||||
}),
|
||||
prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!experiment || !(variables?.length > 0) || !prompt) return {};
|
||||
|
||||
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming["messages"] = [
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"The user is testing multiple scenarios against the same prompt. Attempt to generate a new scenario that is different from the others.",
|
||||
},
|
||||
];
|
||||
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Prompt constructor function:\n---\n${prompt.promptConstructor}`,
|
||||
});
|
||||
|
||||
existingScenarios
|
||||
.map(
|
||||
(scenario) =>
|
||||
pick(
|
||||
scenario.variableValues,
|
||||
variables.map((variable) => variable.label),
|
||||
) as Record<string, string>,
|
||||
)
|
||||
.filter((vals) => Object.keys(vals ?? {}).length > 0)
|
||||
.forEach((vals) => {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: null,
|
||||
function_call: {
|
||||
name: "add_scenario",
|
||||
arguments: JSON.stringify(vals),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const variableProperties = variables.reduce(
|
||||
(acc, variable) => {
|
||||
acc[variable.label] = { type: "string" };
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, { type: "string" }>,
|
||||
);
|
||||
|
||||
try {
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "add_scenario",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: variableProperties,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_scenario" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
completion.choices[0]?.message?.function_call?.arguments ?? "{}",
|
||||
) as Record<string, string>;
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
// If it's an axios error, try to get the error message
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
console.error(e);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
||||
18
app/src/server/api/autogenerate/utils.ts
Normal file
18
app/src/server/api/autogenerate/utils.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
data?: {
|
||||
error?: {
|
||||
message?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export function isAxiosError(error: unknown): error is AxiosError {
|
||||
if (typeof error === "object" && error !== null) {
|
||||
// Initial check
|
||||
const err = error as AxiosError;
|
||||
return err.response?.data?.error?.message !== undefined; // Check structure
|
||||
}
|
||||
return false;
|
||||
}
|
||||
30
app/src/server/api/root.router.ts
Normal file
30
app/src/server/api/root.router.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router";
|
||||
import { createTRPCRouter } from "~/server/api/trpc";
|
||||
import { experimentsRouter } from "./routers/experiments.router";
|
||||
import { scenariosRouter } from "./routers/scenarios.router";
|
||||
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
|
||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||
import { worldChampsRouter } from "./routers/worldChamps.router";
|
||||
import { datasetsRouter } from "./routers/datasets.router";
|
||||
import { datasetEntries } from "./routers/datasetEntries.router";
|
||||
|
||||
/**
|
||||
* This is the primary router for your server.
|
||||
*
|
||||
* All routers added in /api/routers should be manually added here.
|
||||
*/
|
||||
export const appRouter = createTRPCRouter({
|
||||
promptVariants: promptVariantsRouter,
|
||||
experiments: experimentsRouter,
|
||||
scenarios: scenariosRouter,
|
||||
scenarioVariantCells: scenarioVariantCellsRouter,
|
||||
templateVars: templateVarsRouter,
|
||||
evaluations: evaluationsRouter,
|
||||
worldChamps: worldChampsRouter,
|
||||
datasets: datasetsRouter,
|
||||
datasetEntries: datasetEntries,
|
||||
});
|
||||
|
||||
// export type definition of API
|
||||
export type AppRouter = typeof appRouter;
|
||||
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl";
|
||||
import { autogenerateDatasetEntries } from "../autogenerate/autogenerateDatasetEntries";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const datasetEntries = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(z.object({ datasetId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.datasetId, ctx);
|
||||
|
||||
const { datasetId, page } = input;
|
||||
|
||||
const entries = await prisma.datasetEntry.findMany({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
orderBy: { createdAt: "desc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.datasetEntry.count({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
entries,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
createOne: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.create({
|
||||
data: {
|
||||
datasetId: input.datasetId,
|
||||
input: input.input,
|
||||
output: input.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
autogenerateEntries: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
numToGenerate: z.number(),
|
||||
inputDescription: z.string(),
|
||||
outputDescription: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
const dataset = await prisma.dataset.findUnique({
|
||||
where: {
|
||||
id: input.datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dataset) {
|
||||
throw new Error(`Dataset with id ${input.datasetId} does not exist`);
|
||||
}
|
||||
|
||||
const entries = await autogenerateDatasetEntries(
|
||||
input.numToGenerate,
|
||||
input.inputDescription,
|
||||
input.outputDescription,
|
||||
);
|
||||
|
||||
const createdEntries = await prisma.datasetEntry.createMany({
|
||||
data: entries.map((entry) => ({
|
||||
datasetId: input.datasetId,
|
||||
input: entry.input,
|
||||
output: entry.output,
|
||||
})),
|
||||
});
|
||||
|
||||
return createdEntries;
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const datasetId = (
|
||||
await prisma.datasetEntry.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).datasetId;
|
||||
|
||||
await requireCanModifyDataset(datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.datasetEntry.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`dataEntry with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyDataset(existing.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
input: input.updates.input,
|
||||
output: input.updates.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
91
app/src/server/api/routers/datasets.router.ts
Normal file
91
app/src/server/api/routers/datasets.router.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import {
|
||||
requireCanModifyDataset,
|
||||
requireCanViewDataset,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
|
||||
export const datasetsRouter = createTRPCRouter({
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const datasets = await prisma.dataset.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "desc",
|
||||
},
|
||||
include: {
|
||||
_count: {
|
||||
select: { datasetEntries: true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return datasets;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.id, ctx);
|
||||
return await prisma.dataset.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const numDatasets = await prisma.dataset.count({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return await prisma.dataset.create({
|
||||
data: {
|
||||
name: `Dataset ${numDatasets + 1}`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
return await prisma.dataset.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
|
||||
await prisma.dataset.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
94
app/src/server/api/routers/evaluations.router.ts
Normal file
94
app/src/server/api/routers/evaluations.router.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
orderBy: { createdAt: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
evalType: z.nativeEnum(EvalType),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
value: input.value,
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await queueRunNewEval(input.experimentId);
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await queueRunNewEval(experimentId);
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.delete({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/experiments.router.ts
Normal file
419
app/src/server/api/routers/experiments.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import {
|
||||
canModifyExperiment,
|
||||
requireCanModifyExperiment,
|
||||
requireCanViewExperiment,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
import generateTypes from "~/modelProviders/generateTypes";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const experimentsRouter = createTRPCRouter({
|
||||
stats: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [experiment, promptVariantCount, testScenarioCount] = await prisma.$transaction([
|
||||
prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
}),
|
||||
prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
return {
|
||||
experimentLabel: experiment.label,
|
||||
promptVariantCount,
|
||||
testScenarioCount,
|
||||
};
|
||||
}),
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const experiments = await prisma.experiment.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: look for cleaner way to do this. Maybe aggregate?
|
||||
const experimentsWithCounts = await Promise.all(
|
||||
experiments.map(async (experiment) => {
|
||||
const visibleTestScenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
const visiblePromptVariantCount = await prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
testScenarioCount: visibleTestScenarioCount,
|
||||
promptVariantCount: visiblePromptVariantCount,
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
return experimentsWithCounts;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
const experiment = await prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
const canModify = ctx.session?.user.id
|
||||
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
|
||||
: false;
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
access: {
|
||||
canView: true,
|
||||
canModify,
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [
|
||||
existingExp,
|
||||
existingVariants,
|
||||
existingScenarios,
|
||||
existingCells,
|
||||
evaluations,
|
||||
templateVariables,
|
||||
] = await prisma.$transaction([
|
||||
prisma.experiment.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
promptVariant: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
include: {
|
||||
outputEvaluations: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
const newExperimentId = uuidv4();
|
||||
|
||||
const existingToNewVariantIds = new Map<string, string>();
|
||||
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
|
||||
for (const variant of existingVariants) {
|
||||
const newVariantId = uuidv4();
|
||||
existingToNewVariantIds.set(variant.id, newVariantId);
|
||||
variantsToCreate.push({
|
||||
...variant,
|
||||
id: newVariantId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewScenarioIds = new Map<string, string>();
|
||||
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
|
||||
for (const scenario of existingScenarios) {
|
||||
const newScenarioId = uuidv4();
|
||||
existingToNewScenarioIds.set(scenario.id, newScenarioId);
|
||||
scenariosToCreate.push({
|
||||
...scenario,
|
||||
id: newScenarioId,
|
||||
experimentId: newExperimentId,
|
||||
variableValues: scenario.variableValues as Prisma.InputJsonValue,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewEvaluationIds = new Map<string, string>();
|
||||
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
|
||||
for (const evaluation of evaluations) {
|
||||
const newEvaluationId = uuidv4();
|
||||
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
|
||||
evaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: newEvaluationId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
|
||||
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||
for (const cell of existingCells) {
|
||||
const newCellId = uuidv4();
|
||||
const { modelResponses, ...cellData } = cell;
|
||||
cellsToCreate.push({
|
||||
...cellData,
|
||||
id: newCellId,
|
||||
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
|
||||
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const modelResponse of modelResponses) {
|
||||
const newModelResponseId = uuidv4();
|
||||
const { outputEvaluations, ...modelResponseData } = modelResponse;
|
||||
modelResponsesToCreate.push({
|
||||
...modelResponseData,
|
||||
id: newModelResponseId,
|
||||
scenarioVariantCellId: newCellId,
|
||||
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const evaluation of outputEvaluations) {
|
||||
outputEvaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: uuidv4(),
|
||||
modelResponseId: newModelResponseId,
|
||||
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
|
||||
for (const templateVariable of templateVariables) {
|
||||
templateVariablesToCreate.push({
|
||||
...templateVariable,
|
||||
id: uuidv4(),
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.experiment.create({
|
||||
data: {
|
||||
id: newExperimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `${existingExp.label} (forked)`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.createMany({
|
||||
data: variantsToCreate,
|
||||
}),
|
||||
prisma.testScenario.createMany({
|
||||
data: scenariosToCreate,
|
||||
}),
|
||||
prisma.scenarioVariantCell.createMany({
|
||||
data: cellsToCreate,
|
||||
}),
|
||||
prisma.modelResponse.createMany({
|
||||
data: modelResponsesToCreate,
|
||||
}),
|
||||
prisma.evaluation.createMany({
|
||||
data: evaluationsToCreate,
|
||||
}),
|
||||
prisma.outputEvaluation.createMany({
|
||||
data: outputEvaluationsToCreate,
|
||||
}),
|
||||
prisma.templateVariable.createMany({
|
||||
data: templateVariablesToCreate,
|
||||
}),
|
||||
]);
|
||||
|
||||
return newExperimentId;
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const organizationId = (await userOrg(ctx.session.user.id)).id;
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
where: { organizationId },
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const exp = await prisma.experiment.create({
|
||||
data: {
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `Experiment ${maxSortIndex + 1}`,
|
||||
organizationId,
|
||||
},
|
||||
});
|
||||
|
||||
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
|
||||
prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
// The interpolated $ is necessary until dedent incorporates
|
||||
// https://github.com/dmnd/dedent/pull/46
|
||||
promptConstructor: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create).
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
*/
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
modelProvider: "openai/ChatCompletion",
|
||||
promptConstructorVersion,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "language",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "English",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "Spanish",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "German",
|
||||
},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
await generateNewCell(variant.id, scenario1.id);
|
||||
await generateNewCell(variant.id, scenario2.id);
|
||||
await generateNewCell(variant.id, scenario3.id);
|
||||
|
||||
return exp;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
return await prisma.experiment.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
|
||||
await prisma.experiment.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
// Keeping these on `experiment` for now because we might want to limit the
|
||||
// providers based on your account/experiment
|
||||
promptTypes: publicProcedure.query(async () => {
|
||||
return await generateTypes();
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: { sortIndex: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
stats: publicProcedure
|
||||
.input(z.object({ variantId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant) {
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanViewExperiment(variant.experimentId, ctx);
|
||||
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelResponse: {
|
||||
outdated: false,
|
||||
output: { not: Prisma.AnyNull },
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find(
|
||||
(outputEval) => outputEval.evaluationId === evalItem.id,
|
||||
);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
const outputCount = await prisma.scenarioVariantCell.count({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelResponses: {
|
||||
some: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const overallTokens = await prisma.modelResponse.aggregate({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
_sum: {
|
||||
cost: true,
|
||||
promptTokens: true,
|
||||
completionTokens: true,
|
||||
},
|
||||
});
|
||||
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
|
||||
const awaitingEvals = !!evalResults.find(
|
||||
(result) => result.totalCount < scenarioCount * evals.length,
|
||||
);
|
||||
|
||||
return {
|
||||
evalResults,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
overallCost: overallTokens._sum?.cost ?? 0,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingEvals,
|
||||
};
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
let originalVariant: PromptVariant | null = null;
|
||||
if (input.variantId) {
|
||||
originalVariant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
originalVariant = await prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const largestSortIndex =
|
||||
(
|
||||
await prisma.promptVariant.aggregate({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const newVariantLabel =
|
||||
input.variantId && originalVariant
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: newVariantLabel,
|
||||
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||
promptConstructor: newConstructFn,
|
||||
promptConstructorVersion:
|
||||
originalVariant?.promptConstructorVersion ?? promptConstructorVersion,
|
||||
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||
modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
|
||||
},
|
||||
});
|
||||
|
||||
const [newVariant] = await prisma.$transaction([
|
||||
createNewVariantAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
if (originalVariant) {
|
||||
// Insert new variant to right of original variant
|
||||
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||
}
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return newVariant;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const updatePromptVariantAction = prisma.promptVariant.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: input.updates,
|
||||
});
|
||||
|
||||
const [updatedPromptVariant] = await prisma.$transaction([
|
||||
updatePromptVariantAction,
|
||||
recordExperimentUpdated(existing.experimentId),
|
||||
]);
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
hide: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const updatedPromptVariant = await prisma.promptVariant.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
getModifiedPromptFn: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
newModel: z
|
||||
.object({
|
||||
provider: ZodSupportedProvider,
|
||||
model: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
|
||||
|
||||
if ("error" in constructedPrompt) {
|
||||
return userError(constructedPrompt.error);
|
||||
}
|
||||
|
||||
const model = input.newModel
|
||||
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||
: undefined;
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||
|
||||
// TODO: Validate promptConstructionFn
|
||||
// TODO: Record in some sort of history
|
||||
|
||||
return promptConstructionFn;
|
||||
}),
|
||||
|
||||
replaceVariant: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
promptConstructor: z.string(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
const parsedPrompt = await parsePromptConstructor(input.promptConstructor);
|
||||
|
||||
if ("error" in parsedPrompt) {
|
||||
return userError(parsedPrompt.error);
|
||||
}
|
||||
|
||||
// Create a duplicate with only the config changed
|
||||
const newVariant = await prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
label: existing.label,
|
||||
sortIndex: existing.sortIndex,
|
||||
uiId: existing.uiId,
|
||||
promptConstructor: input.promptConstructor,
|
||||
promptConstructorVersion: existing.promptConstructorVersion,
|
||||
modelProvider: parsedPrompt.modelProvider,
|
||||
model: parsedPrompt.model,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide anything with the same uiId besides the new one
|
||||
const hideOldVariants = prisma.promptVariant.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
not: newVariant.id,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
visible: false,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: newVariant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return { status: "ok" } as const;
|
||||
}),
|
||||
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.draggedId },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||
}),
|
||||
});
|
||||
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueQueryModel } from "~/server/tasks/queryModel.task";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
await requireCanViewExperiment(experimentId, ctx);
|
||||
|
||||
const [cell, numTotalEvals] = await prisma.$transaction([
|
||||
prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.count({
|
||||
where: { experimentId },
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!cell) return null;
|
||||
|
||||
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
|
||||
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
|
||||
|
||||
return {
|
||||
...cell,
|
||||
evalsComplete,
|
||||
};
|
||||
}),
|
||||
forceRefetch: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
await generateNewCell(input.variantId, input.scenarioId, { stream: true });
|
||||
return;
|
||||
}
|
||||
|
||||
await prisma.modelResponse.updateMany({
|
||||
where: { scenarioVariantCellId: cell.id },
|
||||
data: {
|
||||
outdated: true,
|
||||
},
|
||||
});
|
||||
|
||||
await queueQueryModel(cell.id, true);
|
||||
}),
|
||||
});
|
||||
252
app/src/server/api/routers/scenarios.router.ts
Normal file
252
app/src/server/api/routers/scenarios.router.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogenerate/autogenerateScenarioValues";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
const { experimentId, page } = input;
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: { sortIndex: "asc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
scenarios,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!scenario) {
|
||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanViewExperiment(scenario.experimentId, ctx);
|
||||
|
||||
return scenario;
|
||||
}),
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
autogenerate: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
data: {
|
||||
sortIndex: {
|
||||
increment: 1,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const createNewScenarioAction = prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
sortIndex: 0,
|
||||
variableValues: input.autogenerate
|
||||
? await autogenerateScenarioValues(input.experimentId)
|
||||
: {},
|
||||
},
|
||||
});
|
||||
|
||||
const [scenario] = await prisma.$transaction([
|
||||
createNewScenarioAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, scenario.id, { stream: true });
|
||||
}
|
||||
}),
|
||||
|
||||
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
const experimentId = (
|
||||
await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).experimentId;
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
const hiddenScenario = await prisma.testScenario.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
// Reevaluate all evaluations now that this scenario is hidden
|
||||
await runAllEvals(hiddenScenario.experimentId);
|
||||
|
||||
return hiddenScenario;
|
||||
}),
|
||||
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const dragged = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.draggedId,
|
||||
},
|
||||
});
|
||||
|
||||
const dropped = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.droppedId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
|
||||
throw new Error(
|
||||
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
|
||||
);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(dragged.experimentId, ctx);
|
||||
|
||||
const visibleItems = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: dragged.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the dragged item from its current position
|
||||
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
|
||||
|
||||
// Find the index of the dragged item and the dropped item
|
||||
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
|
||||
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
|
||||
|
||||
// Determine the new index for the dragged item
|
||||
let newIndex;
|
||||
if (dragIndex < dropIndex) {
|
||||
newIndex = dropIndex + 1; // Insert after the dropped item
|
||||
} else {
|
||||
newIndex = dropIndex; // Insert before the dropped item
|
||||
}
|
||||
|
||||
// Insert the dragged item at the new position
|
||||
orderedItems.splice(newIndex, 0, dragged);
|
||||
|
||||
// Now, we need to update all the items with their new sortIndex
|
||||
await prisma.$transaction(
|
||||
orderedItems.map((item, index) => {
|
||||
return prisma.testScenario.update({
|
||||
where: {
|
||||
id: item.id,
|
||||
},
|
||||
data: {
|
||||
sortIndex: index,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
}),
|
||||
|
||||
replaceWithValues: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
values: z.record(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const newScenario = await prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
sortIndex: existing.sortIndex,
|
||||
variableValues: input.values,
|
||||
uiId: existing.uiId,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide the old scenario
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
not: newScenario.id,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
visible: false,
|
||||
},
|
||||
});
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: newScenario.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, newScenario.id, { stream: true });
|
||||
}
|
||||
|
||||
return newScenario;
|
||||
}),
|
||||
});
|
||||
49
app/src/server/api/routers/templateVariables.router.ts
Normal file
49
app/src/server/api/routers/templateVariables.router.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const templateVarsRouter = createTRPCRouter({
|
||||
create: protectedProcedure
|
||||
.input(z.object({ experimentId: z.string(), label: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.delete({ where: { id: input.id } });
|
||||
}),
|
||||
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
return await prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "asc",
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
label: true,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
36
app/src/server/api/routers/worldChamps.router.ts
Normal file
36
app/src/server/api/routers/worldChamps.router.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireNothing } from "~/utils/accessControl";
|
||||
|
||||
export const worldChampsRouter = createTRPCRouter({
|
||||
userStatus: publicProcedure.query(async ({ ctx }) => {
|
||||
const userId = ctx.session?.user.id;
|
||||
|
||||
if (!userId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await prisma.worldChampEntrant.findUnique({
|
||||
where: { userId },
|
||||
});
|
||||
}),
|
||||
|
||||
apply: protectedProcedure.mutation(async ({ ctx }) => {
|
||||
const userId = ctx.session.user.id;
|
||||
requireNothing(ctx);
|
||||
|
||||
const existingEntrant = await prisma.worldChampEntrant.findUnique({
|
||||
where: { userId },
|
||||
});
|
||||
|
||||
if (existingEntrant) {
|
||||
return existingEntrant;
|
||||
}
|
||||
|
||||
return await prisma.worldChampEntrant.create({
|
||||
data: {
|
||||
userId,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
151
app/src/server/api/trpc.ts
Normal file
151
app/src/server/api/trpc.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* YOU PROBABLY DON'T NEED TO EDIT THIS FILE, UNLESS:
|
||||
* 1. You want to modify request context (see Part 1).
|
||||
* 2. You want to create a new middleware or type of procedure (see Part 3).
|
||||
*
|
||||
* TL;DR - This is where all the tRPC server stuff is created and plugged in. The pieces you will
|
||||
* need to use are documented accordingly near the end.
|
||||
*/
|
||||
|
||||
import { initTRPC, TRPCError } from "@trpc/server";
|
||||
import { type CreateNextContextOptions } from "@trpc/server/adapters/next";
|
||||
import { type Session } from "next-auth";
|
||||
import superjson from "superjson";
|
||||
import { ZodError } from "zod";
|
||||
import { getServerAuthSession } from "~/server/auth";
|
||||
import { prisma } from "~/server/db";
|
||||
import { capturePath } from "~/utils/analytics/serverAnalytics";
|
||||
|
||||
/**
|
||||
* 1. CONTEXT
|
||||
*
|
||||
* This section defines the "contexts" that are available in the backend API.
|
||||
*
|
||||
* These allow you to access things when processing a request, like the database, the session, etc.
|
||||
*/
|
||||
|
||||
type CreateContextOptions = {
|
||||
session: Session | null;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||
const noOp = () => {};
|
||||
|
||||
/**
|
||||
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
|
||||
* it from here.
|
||||
*
|
||||
* Examples of things you may need it for:
|
||||
* - testing, so we don't have to mock Next.js' req/res
|
||||
* - tRPC's `createSSGHelpers`, where we don't have req/res
|
||||
*
|
||||
* @see https://create.t3.gg/en/usage/trpc#-serverapitrpcts
|
||||
*/
|
||||
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
|
||||
return {
|
||||
session: opts.session,
|
||||
prisma,
|
||||
markAccessControlRun: noOp,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* This is the actual context you will use in your router. It will be used to process every request
|
||||
* that goes through your tRPC endpoint.
|
||||
*
|
||||
* @see https://trpc.io/docs/context
|
||||
*/
|
||||
export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
||||
const { req, res } = opts;
|
||||
|
||||
// Get the session from the server using the getServerSession wrapper function
|
||||
const session = await getServerAuthSession({ req, res });
|
||||
|
||||
return createInnerTRPCContext({
|
||||
session,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 2. INITIALIZATION
|
||||
*
|
||||
* This is where the tRPC API is initialized, connecting the context and transformer. We also parse
|
||||
* ZodErrors so that you get typesafety on the frontend if your procedure fails due to validation
|
||||
* errors on the backend.
|
||||
*/
|
||||
|
||||
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
|
||||
|
||||
const t = initTRPC.context<typeof createTRPCContext>().create({
|
||||
transformer: superjson,
|
||||
errorFormatter({ shape, error }) {
|
||||
return {
|
||||
...shape,
|
||||
data: {
|
||||
...shape.data,
|
||||
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* 3. ROUTER & PROCEDURE (THE IMPORTANT BIT)
|
||||
*
|
||||
* These are the pieces you use to build your tRPC API. You should import these a lot in the
|
||||
* "/src/server/api/routers" directory.
|
||||
*/
|
||||
|
||||
/**
|
||||
* This is how you create new routers and sub-routers in your tRPC API.
|
||||
*
|
||||
* @see https://trpc.io/docs/router
|
||||
*/
|
||||
export const createTRPCRouter = t.router;
|
||||
|
||||
/**
|
||||
* Public (unauthenticated) procedure
|
||||
*
|
||||
* This is the base piece you use to build new queries and mutations on your tRPC API. It does not
|
||||
* guarantee that a user querying is authorized, but you can still access user session data if they
|
||||
* are logged in.
|
||||
*/
|
||||
export const publicProcedure = t.procedure;
|
||||
|
||||
/** Reusable middleware that enforces users are logged in before running the procedure. */
|
||||
const enforceUserIsAuthed = t.middleware(async ({ ctx, next, path }) => {
|
||||
if (!ctx.session || !ctx.session.user) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
|
||||
let accessControlRun = false;
|
||||
const resp = await next({
|
||||
ctx: {
|
||||
// infers the `session` as non-nullable
|
||||
session: { ...ctx.session, user: ctx.session.user },
|
||||
markAccessControlRun: () => {
|
||||
accessControlRun = true;
|
||||
},
|
||||
},
|
||||
});
|
||||
if (!accessControlRun)
|
||||
throw new TRPCError({
|
||||
code: "INTERNAL_SERVER_ERROR",
|
||||
message:
|
||||
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
|
||||
});
|
||||
|
||||
capturePath(ctx.session, path);
|
||||
|
||||
return resp;
|
||||
});
|
||||
|
||||
/**
|
||||
* Protected (authenticated) procedure
|
||||
*
|
||||
* If you want a query or mutation to ONLY be accessible to logged in users, use this. It verifies
|
||||
* the session is valid and guarantees `ctx.session.user` is not null.
|
||||
*
|
||||
* @see https://trpc.io/docs/procedures
|
||||
*/
|
||||
export const protectedProcedure = t.procedure.use(enforceUserIsAuthed);
|
||||
Reference in New Issue
Block a user