move app to app/ subdir
This commit is contained in:
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
108
app/src/server/api/autogenerate/autogenerateDatasetEntries.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { isAxiosError } from "./utils";
|
||||
import { type APIResponse } from "openai/core";
|
||||
import { sleep } from "~/server/utils/sleep";
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
const getCompletionWithBackoff = async (
|
||||
getCompletion: () => Promise<APIResponse<ChatCompletion>>,
|
||||
) => {
|
||||
let completion;
|
||||
let tries = 0;
|
||||
while (tries < MAX_AUTO_RETRIES) {
|
||||
try {
|
||||
completion = await getCompletion();
|
||||
break;
|
||||
} catch (e) {
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
await sleep(calculateDelay(tries));
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
tries++;
|
||||
}
|
||||
return completion;
|
||||
};
|
||||
// TODO: Add seeds to ensure batches don't contain duplicate data
|
||||
const MAX_BATCH_SIZE = 5;
|
||||
|
||||
export const autogenerateDatasetEntries = async (
|
||||
numToGenerate: number,
|
||||
inputDescription: string,
|
||||
outputDescription: string,
|
||||
): Promise<{ input: string; output: string }[]> => {
|
||||
const batchSizes = Array.from({ length: Math.ceil(numToGenerate / MAX_BATCH_SIZE) }, (_, i) =>
|
||||
i === Math.ceil(numToGenerate / MAX_BATCH_SIZE) - 1 && numToGenerate % MAX_BATCH_SIZE
|
||||
? numToGenerate % MAX_BATCH_SIZE
|
||||
: MAX_BATCH_SIZE,
|
||||
);
|
||||
|
||||
const getCompletion = (batchSize: number) =>
|
||||
openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: `The user needs ${batchSize} rows of data, each with an input and an output.\n---\n The input should follow these requirements: ${inputDescription}\n---\n The output should follow these requirements: ${outputDescription}`,
|
||||
},
|
||||
],
|
||||
functions: [
|
||||
{
|
||||
name: "add_list_of_data",
|
||||
description: "Add a list of data to the database",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
rows: {
|
||||
type: "array",
|
||||
description: "The rows of data that match the description",
|
||||
items: {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "The input for this row",
|
||||
},
|
||||
output: {
|
||||
type: "string",
|
||||
description: "The output for this row",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_list_of_data" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const completionCallbacks = batchSizes.map((batchSize) =>
|
||||
getCompletionWithBackoff(() => getCompletion(batchSize)),
|
||||
);
|
||||
|
||||
const completions = await Promise.all(completionCallbacks);
|
||||
|
||||
const rows = completions.flatMap((completion) => {
|
||||
const parsed = JSON.parse(
|
||||
completion?.choices[0]?.message?.function_call?.arguments ?? "{rows: []}",
|
||||
) as { rows: { input: string; output: string }[] };
|
||||
return parsed.rows;
|
||||
});
|
||||
|
||||
return rows;
|
||||
};
|
||||
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
118
app/src/server/api/autogenerate/autogenerateScenarioValues.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { prisma } from "../../db";
|
||||
import { openai } from "../../utils/openai";
|
||||
import { pick } from "lodash-es";
|
||||
import { isAxiosError } from "./utils";
|
||||
|
||||
export const autogenerateScenarioValues = async (
|
||||
experimentId: string,
|
||||
): Promise<Record<string, string>> => {
|
||||
const [experiment, variables, existingScenarios, prompt] = await Promise.all([
|
||||
prisma.experiment.findUnique({
|
||||
where: {
|
||||
id: experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
take: 10,
|
||||
}),
|
||||
prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!experiment || !(variables?.length > 0) || !prompt) return {};
|
||||
|
||||
const messages: CompletionCreateParams.CreateChatCompletionRequestNonStreaming["messages"] = [
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"The user is testing multiple scenarios against the same prompt. Attempt to generate a new scenario that is different from the others.",
|
||||
},
|
||||
];
|
||||
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Prompt constructor function:\n---\n${prompt.promptConstructor}`,
|
||||
});
|
||||
|
||||
existingScenarios
|
||||
.map(
|
||||
(scenario) =>
|
||||
pick(
|
||||
scenario.variableValues,
|
||||
variables.map((variable) => variable.label),
|
||||
) as Record<string, string>,
|
||||
)
|
||||
.filter((vals) => Object.keys(vals ?? {}).length > 0)
|
||||
.forEach((vals) => {
|
||||
messages.push({
|
||||
role: "assistant",
|
||||
content: null,
|
||||
function_call: {
|
||||
name: "add_scenario",
|
||||
arguments: JSON.stringify(vals),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const variableProperties = variables.reduce(
|
||||
(acc, variable) => {
|
||||
acc[variable.label] = { type: "string" };
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, { type: "string" }>,
|
||||
);
|
||||
|
||||
try {
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "add_scenario",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: variableProperties,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
function_call: { name: "add_scenario" },
|
||||
temperature: 0.5,
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(
|
||||
completion.choices[0]?.message?.function_call?.arguments ?? "{}",
|
||||
) as Record<string, string>;
|
||||
return parsed;
|
||||
} catch (e) {
|
||||
// If it's an axios error, try to get the error message
|
||||
if (isAxiosError(e)) {
|
||||
console.error(e?.response?.data?.error?.message);
|
||||
} else {
|
||||
console.error(e);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
||||
18
app/src/server/api/autogenerate/utils.ts
Normal file
18
app/src/server/api/autogenerate/utils.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
type AxiosError = {
|
||||
response?: {
|
||||
data?: {
|
||||
error?: {
|
||||
message?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export function isAxiosError(error: unknown): error is AxiosError {
|
||||
if (typeof error === "object" && error !== null) {
|
||||
// Initial check
|
||||
const err = error as AxiosError;
|
||||
return err.response?.data?.error?.message !== undefined; // Check structure
|
||||
}
|
||||
return false;
|
||||
}
|
||||
30
app/src/server/api/root.router.ts
Normal file
30
app/src/server/api/root.router.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { promptVariantsRouter } from "~/server/api/routers/promptVariants.router";
|
||||
import { createTRPCRouter } from "~/server/api/trpc";
|
||||
import { experimentsRouter } from "./routers/experiments.router";
|
||||
import { scenariosRouter } from "./routers/scenarios.router";
|
||||
import { scenarioVariantCellsRouter } from "./routers/scenarioVariantCells.router";
|
||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||
import { worldChampsRouter } from "./routers/worldChamps.router";
|
||||
import { datasetsRouter } from "./routers/datasets.router";
|
||||
import { datasetEntries } from "./routers/datasetEntries.router";
|
||||
|
||||
/**
|
||||
* This is the primary router for your server.
|
||||
*
|
||||
* All routers added in /api/routers should be manually added here.
|
||||
*/
|
||||
export const appRouter = createTRPCRouter({
|
||||
promptVariants: promptVariantsRouter,
|
||||
experiments: experimentsRouter,
|
||||
scenarios: scenariosRouter,
|
||||
scenarioVariantCells: scenarioVariantCellsRouter,
|
||||
templateVars: templateVarsRouter,
|
||||
evaluations: evaluationsRouter,
|
||||
worldChamps: worldChampsRouter,
|
||||
datasets: datasetsRouter,
|
||||
datasetEntries: datasetEntries,
|
||||
});
|
||||
|
||||
// export type definition of API
|
||||
export type AppRouter = typeof appRouter;
|
||||
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
149
app/src/server/api/routers/datasetEntries.router.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyDataset, requireCanViewDataset } from "~/utils/accessControl";
|
||||
import { autogenerateDatasetEntries } from "../autogenerate/autogenerateDatasetEntries";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const datasetEntries = createTRPCRouter({
|
||||
list: protectedProcedure
|
||||
.input(z.object({ datasetId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.datasetId, ctx);
|
||||
|
||||
const { datasetId, page } = input;
|
||||
|
||||
const entries = await prisma.datasetEntry.findMany({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
orderBy: { createdAt: "desc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.datasetEntry.count({
|
||||
where: {
|
||||
datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
entries,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
createOne: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.create({
|
||||
data: {
|
||||
datasetId: input.datasetId,
|
||||
input: input.input,
|
||||
output: input.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
autogenerateEntries: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
datasetId: z.string(),
|
||||
numToGenerate: z.number(),
|
||||
inputDescription: z.string(),
|
||||
outputDescription: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.datasetId, ctx);
|
||||
|
||||
const dataset = await prisma.dataset.findUnique({
|
||||
where: {
|
||||
id: input.datasetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dataset) {
|
||||
throw new Error(`Dataset with id ${input.datasetId} does not exist`);
|
||||
}
|
||||
|
||||
const entries = await autogenerateDatasetEntries(
|
||||
input.numToGenerate,
|
||||
input.inputDescription,
|
||||
input.outputDescription,
|
||||
);
|
||||
|
||||
const createdEntries = await prisma.datasetEntry.createMany({
|
||||
data: entries.map((entry) => ({
|
||||
datasetId: input.datasetId,
|
||||
input: entry.input,
|
||||
output: entry.output,
|
||||
})),
|
||||
});
|
||||
|
||||
return createdEntries;
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const datasetId = (
|
||||
await prisma.datasetEntry.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).datasetId;
|
||||
|
||||
await requireCanModifyDataset(datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
input: z.string(),
|
||||
output: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.datasetEntry.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`dataEntry with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyDataset(existing.datasetId, ctx);
|
||||
|
||||
return await prisma.datasetEntry.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
input: input.updates.input,
|
||||
output: input.updates.output,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
91
app/src/server/api/routers/datasets.router.ts
Normal file
91
app/src/server/api/routers/datasets.router.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import {
|
||||
requireCanModifyDataset,
|
||||
requireCanViewDataset,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
|
||||
export const datasetsRouter = createTRPCRouter({
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const datasets = await prisma.dataset.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "desc",
|
||||
},
|
||||
include: {
|
||||
_count: {
|
||||
select: { datasetEntries: true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return datasets;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewDataset(input.id, ctx);
|
||||
return await prisma.dataset.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const numDatasets = await prisma.dataset.count({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return await prisma.dataset.create({
|
||||
data: {
|
||||
name: `Dataset ${numDatasets + 1}`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
return await prisma.dataset.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
name: input.updates.name,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyDataset(input.id, ctx);
|
||||
|
||||
await prisma.dataset.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
94
app/src/server/api/routers/evaluations.router.ts
Normal file
94
app/src/server/api/routers/evaluations.router.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueRunNewEval } from "~/server/tasks/runNewEval.task";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
orderBy: { createdAt: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
evalType: z.nativeEnum(EvalType),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
value: input.value,
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await queueRunNewEval(input.experimentId);
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await queueRunNewEval(experimentId);
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.evaluation.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.evaluation.delete({
|
||||
where: { id: input.id },
|
||||
});
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/experiments.router.ts
Normal file
419
app/src/server/api/routers/experiments.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import {
|
||||
canModifyExperiment,
|
||||
requireCanModifyExperiment,
|
||||
requireCanViewExperiment,
|
||||
requireNothing,
|
||||
} from "~/utils/accessControl";
|
||||
import userOrg from "~/server/utils/userOrg";
|
||||
import generateTypes from "~/modelProviders/generateTypes";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const experimentsRouter = createTRPCRouter({
|
||||
stats: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [experiment, promptVariantCount, testScenarioCount] = await prisma.$transaction([
|
||||
prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
}),
|
||||
prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
return {
|
||||
experimentLabel: experiment.label,
|
||||
promptVariantCount,
|
||||
testScenarioCount,
|
||||
};
|
||||
}),
|
||||
list: protectedProcedure.query(async ({ ctx }) => {
|
||||
// Anyone can list experiments
|
||||
requireNothing(ctx);
|
||||
|
||||
const experiments = await prisma.experiment.findMany({
|
||||
where: {
|
||||
organization: {
|
||||
organizationUsers: {
|
||||
some: { userId: ctx.session.user.id },
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
|
||||
// TODO: look for cleaner way to do this. Maybe aggregate?
|
||||
const experimentsWithCounts = await Promise.all(
|
||||
experiments.map(async (experiment) => {
|
||||
const visibleTestScenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
const visiblePromptVariantCount = await prisma.promptVariant.count({
|
||||
where: {
|
||||
experimentId: experiment.id,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
testScenarioCount: visibleTestScenarioCount,
|
||||
promptVariantCount: visiblePromptVariantCount,
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
return experimentsWithCounts;
|
||||
}),
|
||||
|
||||
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
const experiment = await prisma.experiment.findFirstOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
const canModify = ctx.session?.user.id
|
||||
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
|
||||
: false;
|
||||
|
||||
return {
|
||||
...experiment,
|
||||
access: {
|
||||
canView: true,
|
||||
canModify,
|
||||
},
|
||||
};
|
||||
}),
|
||||
|
||||
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.id, ctx);
|
||||
|
||||
const [
|
||||
existingExp,
|
||||
existingVariants,
|
||||
existingScenarios,
|
||||
existingCells,
|
||||
evaluations,
|
||||
templateVariables,
|
||||
] = await prisma.$transaction([
|
||||
prisma.experiment.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
}),
|
||||
prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
promptVariant: {
|
||||
experimentId: input.id,
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
include: {
|
||||
outputEvaluations: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.id,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
const newExperimentId = uuidv4();
|
||||
|
||||
const existingToNewVariantIds = new Map<string, string>();
|
||||
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
|
||||
for (const variant of existingVariants) {
|
||||
const newVariantId = uuidv4();
|
||||
existingToNewVariantIds.set(variant.id, newVariantId);
|
||||
variantsToCreate.push({
|
||||
...variant,
|
||||
id: newVariantId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewScenarioIds = new Map<string, string>();
|
||||
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
|
||||
for (const scenario of existingScenarios) {
|
||||
const newScenarioId = uuidv4();
|
||||
existingToNewScenarioIds.set(scenario.id, newScenarioId);
|
||||
scenariosToCreate.push({
|
||||
...scenario,
|
||||
id: newScenarioId,
|
||||
experimentId: newExperimentId,
|
||||
variableValues: scenario.variableValues as Prisma.InputJsonValue,
|
||||
});
|
||||
}
|
||||
|
||||
const existingToNewEvaluationIds = new Map<string, string>();
|
||||
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
|
||||
for (const evaluation of evaluations) {
|
||||
const newEvaluationId = uuidv4();
|
||||
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
|
||||
evaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: newEvaluationId,
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
|
||||
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
|
||||
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
|
||||
for (const cell of existingCells) {
|
||||
const newCellId = uuidv4();
|
||||
const { modelResponses, ...cellData } = cell;
|
||||
cellsToCreate.push({
|
||||
...cellData,
|
||||
id: newCellId,
|
||||
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
|
||||
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
|
||||
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const modelResponse of modelResponses) {
|
||||
const newModelResponseId = uuidv4();
|
||||
const { outputEvaluations, ...modelResponseData } = modelResponse;
|
||||
modelResponsesToCreate.push({
|
||||
...modelResponseData,
|
||||
id: newModelResponseId,
|
||||
scenarioVariantCellId: newCellId,
|
||||
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
|
||||
});
|
||||
for (const evaluation of outputEvaluations) {
|
||||
outputEvaluationsToCreate.push({
|
||||
...evaluation,
|
||||
id: uuidv4(),
|
||||
modelResponseId: newModelResponseId,
|
||||
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
|
||||
for (const templateVariable of templateVariables) {
|
||||
templateVariablesToCreate.push({
|
||||
...templateVariable,
|
||||
id: uuidv4(),
|
||||
experimentId: newExperimentId,
|
||||
});
|
||||
}
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.experiment.create({
|
||||
data: {
|
||||
id: newExperimentId,
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `${existingExp.label} (forked)`,
|
||||
organizationId: (await userOrg(ctx.session.user.id)).id,
|
||||
},
|
||||
}),
|
||||
prisma.promptVariant.createMany({
|
||||
data: variantsToCreate,
|
||||
}),
|
||||
prisma.testScenario.createMany({
|
||||
data: scenariosToCreate,
|
||||
}),
|
||||
prisma.scenarioVariantCell.createMany({
|
||||
data: cellsToCreate,
|
||||
}),
|
||||
prisma.modelResponse.createMany({
|
||||
data: modelResponsesToCreate,
|
||||
}),
|
||||
prisma.evaluation.createMany({
|
||||
data: evaluationsToCreate,
|
||||
}),
|
||||
prisma.outputEvaluation.createMany({
|
||||
data: outputEvaluationsToCreate,
|
||||
}),
|
||||
prisma.templateVariable.createMany({
|
||||
data: templateVariablesToCreate,
|
||||
}),
|
||||
]);
|
||||
|
||||
return newExperimentId;
|
||||
}),
|
||||
|
||||
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => {
|
||||
// Anyone can create an experiment
|
||||
requireNothing(ctx);
|
||||
|
||||
const organizationId = (await userOrg(ctx.session.user.id)).id;
|
||||
|
||||
const maxSortIndex =
|
||||
(
|
||||
await prisma.experiment.aggregate({
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
where: { organizationId },
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const exp = await prisma.experiment.create({
|
||||
data: {
|
||||
sortIndex: maxSortIndex + 1,
|
||||
label: `Experiment ${maxSortIndex + 1}`,
|
||||
organizationId,
|
||||
},
|
||||
});
|
||||
|
||||
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
|
||||
prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "Prompt Variant 1",
|
||||
sortIndex: 0,
|
||||
// The interpolated $ is necessary until dedent incorporates
|
||||
// https://github.com/dmnd/dedent/pull/46
|
||||
promptConstructor: dedent`
|
||||
/**
|
||||
* Use Javascript to define an OpenAI chat completion
|
||||
* (https://platform.openai.com/docs/api-reference/chat/create).
|
||||
*
|
||||
* You have access to the current scenario in the \`scenario\`
|
||||
* variable.
|
||||
*/
|
||||
|
||||
definePrompt("openai/ChatCompletion", {
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: \`Write 'Start experimenting!' in ${"$"}{scenario.language}\`,
|
||||
},
|
||||
],
|
||||
});`,
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
modelProvider: "openai/ChatCompletion",
|
||||
promptConstructorVersion,
|
||||
},
|
||||
}),
|
||||
prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
label: "language",
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "English",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "Spanish",
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: exp.id,
|
||||
variableValues: {
|
||||
language: "German",
|
||||
},
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
await generateNewCell(variant.id, scenario1.id);
|
||||
await generateNewCell(variant.id, scenario2.id);
|
||||
await generateNewCell(variant.id, scenario3.id);
|
||||
|
||||
return exp;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
return await prisma.experiment.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: {
|
||||
label: input.updates.label,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.id, ctx);
|
||||
|
||||
await prisma.experiment.delete({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
// Keeping these on `experiment` for now because we might want to limit the
|
||||
// providers based on your account/experiment
|
||||
promptTypes: publicProcedure.query(async () => {
|
||||
return await generateTypes();
|
||||
}),
|
||||
});
|
||||
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
419
app/src/server/api/routers/promptVariants.router.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import userError from "~/server/utils/error";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reorderPromptVariants } from "~/server/utils/reorderPromptVariants";
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import { deriveNewConstructFn } from "~/server/utils/deriveNewContructFn";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { ZodSupportedProvider } from "~/modelProviders/types";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
|
||||
export const promptVariantsRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
return await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: { sortIndex: "asc" },
|
||||
});
|
||||
}),
|
||||
|
||||
stats: publicProcedure
|
||||
.input(z.object({ variantId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant) {
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanViewExperiment(variant.experimentId, ctx);
|
||||
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelResponse: {
|
||||
outdated: false,
|
||||
output: { not: Prisma.AnyNull },
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find(
|
||||
(outputEval) => outputEval.evaluationId === evalItem.id,
|
||||
);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
const outputCount = await prisma.scenarioVariantCell.count({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelResponses: {
|
||||
some: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const overallTokens = await prisma.modelResponse.aggregate({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
_sum: {
|
||||
cost: true,
|
||||
promptTokens: true,
|
||||
completionTokens: true,
|
||||
},
|
||||
});
|
||||
|
||||
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
|
||||
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
|
||||
|
||||
const awaitingEvals = !!evalResults.find(
|
||||
(result) => result.totalCount < scenarioCount * evals.length,
|
||||
);
|
||||
|
||||
return {
|
||||
evalResults,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
overallCost: overallTokens._sum?.cost ?? 0,
|
||||
scenarioCount,
|
||||
outputCount,
|
||||
awaitingEvals,
|
||||
};
|
||||
}),
|
||||
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
variantId: z.string().optional(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
let originalVariant: PromptVariant | null = null;
|
||||
if (input.variantId) {
|
||||
originalVariant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.variantId,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
originalVariant = await prisma.promptVariant.findFirst({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "desc",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const largestSortIndex =
|
||||
(
|
||||
await prisma.promptVariant.aggregate({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
_max: {
|
||||
sortIndex: true,
|
||||
},
|
||||
})
|
||||
)._max?.sortIndex ?? 0;
|
||||
|
||||
const newVariantLabel =
|
||||
input.variantId && originalVariant
|
||||
? `${originalVariant?.label} Copy`
|
||||
: `Prompt Variant ${largestSortIndex + 2}`;
|
||||
|
||||
const newConstructFn = await deriveNewConstructFn(originalVariant);
|
||||
|
||||
const createNewVariantAction = prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: newVariantLabel,
|
||||
sortIndex: (originalVariant?.sortIndex ?? 0) + 1,
|
||||
promptConstructor: newConstructFn,
|
||||
promptConstructorVersion:
|
||||
originalVariant?.promptConstructorVersion ?? promptConstructorVersion,
|
||||
model: originalVariant?.model ?? "gpt-3.5-turbo",
|
||||
modelProvider: originalVariant?.modelProvider ?? "openai/ChatCompletion",
|
||||
},
|
||||
});
|
||||
|
||||
const [newVariant] = await prisma.$transaction([
|
||||
createNewVariantAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
if (originalVariant) {
|
||||
// Insert new variant to right of original variant
|
||||
await reorderPromptVariants(newVariant.id, originalVariant.id, true);
|
||||
}
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return newVariant;
|
||||
}),
|
||||
|
||||
update: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
label: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const updatePromptVariantAction = prisma.promptVariant.update({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
data: input.updates,
|
||||
});
|
||||
|
||||
const [updatedPromptVariant] = await prisma.$transaction([
|
||||
updatePromptVariantAction,
|
||||
recordExperimentUpdated(existing.experimentId),
|
||||
]);
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
hide: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const updatedPromptVariant = await prisma.promptVariant.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
return updatedPromptVariant;
|
||||
}),
|
||||
|
||||
getModifiedPromptFn: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
instructions: z.string().optional(),
|
||||
newModel: z
|
||||
.object({
|
||||
provider: ZodSupportedProvider,
|
||||
model: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const constructedPrompt = await parsePromptConstructor(existing.promptConstructor);
|
||||
|
||||
if ("error" in constructedPrompt) {
|
||||
return userError(constructedPrompt.error);
|
||||
}
|
||||
|
||||
const model = input.newModel
|
||||
? modelProviders[input.newModel.provider].models[input.newModel.model]
|
||||
: undefined;
|
||||
|
||||
const promptConstructionFn = await deriveNewConstructFn(existing, model, input.instructions);
|
||||
|
||||
// TODO: Validate promptConstructionFn
|
||||
// TODO: Record in some sort of history
|
||||
|
||||
return promptConstructionFn;
|
||||
}),
|
||||
|
||||
replaceVariant: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
promptConstructor: z.string(),
|
||||
streamScenarios: z.array(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Prompt Variant with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
const parsedPrompt = await parsePromptConstructor(input.promptConstructor);
|
||||
|
||||
if ("error" in parsedPrompt) {
|
||||
return userError(parsedPrompt.error);
|
||||
}
|
||||
|
||||
// Create a duplicate with only the config changed
|
||||
const newVariant = await prisma.promptVariant.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
label: existing.label,
|
||||
sortIndex: existing.sortIndex,
|
||||
uiId: existing.uiId,
|
||||
promptConstructor: input.promptConstructor,
|
||||
promptConstructorVersion: existing.promptConstructorVersion,
|
||||
modelProvider: parsedPrompt.modelProvider,
|
||||
model: parsedPrompt.model,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide anything with the same uiId besides the new one
|
||||
const hideOldVariants = prisma.promptVariant.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
not: newVariant.id,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
visible: false,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.$transaction([hideOldVariants, recordExperimentUpdated(existing.experimentId)]);
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: newVariant.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const scenario of scenarios) {
|
||||
await generateNewCell(newVariant.id, scenario.id, {
|
||||
stream: input.streamScenarios.includes(scenario.id),
|
||||
});
|
||||
}
|
||||
|
||||
return { status: "ok" } as const;
|
||||
}),
|
||||
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.promptVariant.findUniqueOrThrow({
|
||||
where: { id: input.draggedId },
|
||||
});
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await reorderPromptVariants(input.draggedId, input.droppedId);
|
||||
}),
|
||||
});
|
||||
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
99
app/src/server/api/routers/scenarioVariantCells.router.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { queueQueryModel } from "~/server/tasks/queryModel.task";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
get: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
await requireCanViewExperiment(experimentId, ctx);
|
||||
|
||||
const [cell, numTotalEvals] = await prisma.$transaction([
|
||||
prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
prisma.evaluation.count({
|
||||
where: { experimentId },
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!cell) return null;
|
||||
|
||||
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
|
||||
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
|
||||
|
||||
return {
|
||||
...cell,
|
||||
evalsComplete,
|
||||
};
|
||||
}),
|
||||
forceRefetch: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
scenarioId: z.string(),
|
||||
variantId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.scenarioId },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!cell) {
|
||||
await generateNewCell(input.variantId, input.scenarioId, { stream: true });
|
||||
return;
|
||||
}
|
||||
|
||||
await prisma.modelResponse.updateMany({
|
||||
where: { scenarioVariantCellId: cell.id },
|
||||
data: {
|
||||
outdated: true,
|
||||
},
|
||||
});
|
||||
|
||||
await queueQueryModel(cell.id, true);
|
||||
}),
|
||||
});
|
||||
252
app/src/server/api/routers/scenarios.router.ts
Normal file
252
app/src/server/api/routers/scenarios.router.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogenerate/autogenerateScenarioValues";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
const PAGE_SIZE = 10;
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string(), page: z.number() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
|
||||
const { experimentId, page } = input;
|
||||
|
||||
const scenarios = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: { sortIndex: "asc" },
|
||||
skip: (page - 1) * PAGE_SIZE,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
const count = await prisma.testScenario.count({
|
||||
where: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
scenarios,
|
||||
startIndex: (page - 1) * PAGE_SIZE + 1,
|
||||
lastPage: Math.ceil(count / PAGE_SIZE),
|
||||
count,
|
||||
};
|
||||
}),
|
||||
get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!scenario) {
|
||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanViewExperiment(scenario.experimentId, ctx);
|
||||
|
||||
return scenario;
|
||||
}),
|
||||
create: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
experimentId: z.string(),
|
||||
autogenerate: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
data: {
|
||||
sortIndex: {
|
||||
increment: 1,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const createNewScenarioAction = prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
sortIndex: 0,
|
||||
variableValues: input.autogenerate
|
||||
? await autogenerateScenarioValues(input.experimentId)
|
||||
: {},
|
||||
},
|
||||
});
|
||||
|
||||
const [scenario] = await prisma.$transaction([
|
||||
createNewScenarioAction,
|
||||
recordExperimentUpdated(input.experimentId),
|
||||
]);
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, scenario.id, { stream: true });
|
||||
}
|
||||
}),
|
||||
|
||||
hide: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
const experimentId = (
|
||||
await prisma.testScenario.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
})
|
||||
).experimentId;
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
const hiddenScenario = await prisma.testScenario.update({
|
||||
where: { id: input.id },
|
||||
data: { visible: false, experiment: { update: { updatedAt: new Date() } } },
|
||||
});
|
||||
|
||||
// Reevaluate all evaluations now that this scenario is hidden
|
||||
await runAllEvals(hiddenScenario.experimentId);
|
||||
|
||||
return hiddenScenario;
|
||||
}),
|
||||
|
||||
reorder: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
draggedId: z.string(),
|
||||
droppedId: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const dragged = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.draggedId,
|
||||
},
|
||||
});
|
||||
|
||||
const dropped = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.droppedId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dragged || !dropped || dragged.experimentId !== dropped.experimentId) {
|
||||
throw new Error(
|
||||
`Prompt Variant with id ${input.draggedId} or ${input.droppedId} does not exist`,
|
||||
);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(dragged.experimentId, ctx);
|
||||
|
||||
const visibleItems = await prisma.testScenario.findMany({
|
||||
where: {
|
||||
experimentId: dragged.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the dragged item from its current position
|
||||
const orderedItems = visibleItems.filter((item) => item.id !== dragged.id);
|
||||
|
||||
// Find the index of the dragged item and the dropped item
|
||||
const dragIndex = visibleItems.findIndex((item) => item.id === dragged.id);
|
||||
const dropIndex = visibleItems.findIndex((item) => item.id === dropped.id);
|
||||
|
||||
// Determine the new index for the dragged item
|
||||
let newIndex;
|
||||
if (dragIndex < dropIndex) {
|
||||
newIndex = dropIndex + 1; // Insert after the dropped item
|
||||
} else {
|
||||
newIndex = dropIndex; // Insert before the dropped item
|
||||
}
|
||||
|
||||
// Insert the dragged item at the new position
|
||||
orderedItems.splice(newIndex, 0, dragged);
|
||||
|
||||
// Now, we need to update all the items with their new sortIndex
|
||||
await prisma.$transaction(
|
||||
orderedItems.map((item, index) => {
|
||||
return prisma.testScenario.update({
|
||||
where: {
|
||||
id: item.id,
|
||||
},
|
||||
data: {
|
||||
sortIndex: index,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
}),
|
||||
|
||||
replaceWithValues: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
values: z.record(z.string()),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const existing = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: input.id,
|
||||
},
|
||||
});
|
||||
|
||||
if (!existing) {
|
||||
throw new Error(`Scenario with id ${input.id} does not exist`);
|
||||
}
|
||||
|
||||
await requireCanModifyExperiment(existing.experimentId, ctx);
|
||||
|
||||
const newScenario = await prisma.testScenario.create({
|
||||
data: {
|
||||
experimentId: existing.experimentId,
|
||||
sortIndex: existing.sortIndex,
|
||||
variableValues: input.values,
|
||||
uiId: existing.uiId,
|
||||
},
|
||||
});
|
||||
|
||||
// Hide the old scenario
|
||||
await prisma.testScenario.updateMany({
|
||||
where: {
|
||||
uiId: existing.uiId,
|
||||
id: {
|
||||
not: newScenario.id,
|
||||
},
|
||||
},
|
||||
data: {
|
||||
visible: false,
|
||||
},
|
||||
});
|
||||
|
||||
const promptVariants = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: newScenario.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
});
|
||||
|
||||
for (const variant of promptVariants) {
|
||||
await generateNewCell(variant.id, newScenario.id, { stream: true });
|
||||
}
|
||||
|
||||
return newScenario;
|
||||
}),
|
||||
});
|
||||
49
app/src/server/api/routers/templateVariables.router.ts
Normal file
49
app/src/server/api/routers/templateVariables.router.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireCanModifyExperiment, requireCanViewExperiment } from "~/utils/accessControl";
|
||||
|
||||
export const templateVarsRouter = createTRPCRouter({
|
||||
create: protectedProcedure
|
||||
.input(z.object({ experimentId: z.string(), label: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
await requireCanModifyExperiment(input.experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
delete: protectedProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { experimentId } = await prisma.templateVariable.findUniqueOrThrow({
|
||||
where: { id: input.id },
|
||||
});
|
||||
|
||||
await requireCanModifyExperiment(experimentId, ctx);
|
||||
|
||||
await prisma.templateVariable.delete({ where: { id: input.id } });
|
||||
}),
|
||||
|
||||
list: publicProcedure
|
||||
.input(z.object({ experimentId: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
await requireCanViewExperiment(input.experimentId, ctx);
|
||||
return await prisma.templateVariable.findMany({
|
||||
where: {
|
||||
experimentId: input.experimentId,
|
||||
},
|
||||
orderBy: {
|
||||
createdAt: "asc",
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
label: true,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
36
app/src/server/api/routers/worldChamps.router.ts
Normal file
36
app/src/server/api/routers/worldChamps.router.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { requireNothing } from "~/utils/accessControl";
|
||||
|
||||
export const worldChampsRouter = createTRPCRouter({
|
||||
userStatus: publicProcedure.query(async ({ ctx }) => {
|
||||
const userId = ctx.session?.user.id;
|
||||
|
||||
if (!userId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await prisma.worldChampEntrant.findUnique({
|
||||
where: { userId },
|
||||
});
|
||||
}),
|
||||
|
||||
apply: protectedProcedure.mutation(async ({ ctx }) => {
|
||||
const userId = ctx.session.user.id;
|
||||
requireNothing(ctx);
|
||||
|
||||
const existingEntrant = await prisma.worldChampEntrant.findUnique({
|
||||
where: { userId },
|
||||
});
|
||||
|
||||
if (existingEntrant) {
|
||||
return existingEntrant;
|
||||
}
|
||||
|
||||
return await prisma.worldChampEntrant.create({
|
||||
data: {
|
||||
userId,
|
||||
},
|
||||
});
|
||||
}),
|
||||
});
|
||||
151
app/src/server/api/trpc.ts
Normal file
151
app/src/server/api/trpc.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* YOU PROBABLY DON'T NEED TO EDIT THIS FILE, UNLESS:
|
||||
* 1. You want to modify request context (see Part 1).
|
||||
* 2. You want to create a new middleware or type of procedure (see Part 3).
|
||||
*
|
||||
* TL;DR - This is where all the tRPC server stuff is created and plugged in. The pieces you will
|
||||
* need to use are documented accordingly near the end.
|
||||
*/
|
||||
|
||||
import { initTRPC, TRPCError } from "@trpc/server";
|
||||
import { type CreateNextContextOptions } from "@trpc/server/adapters/next";
|
||||
import { type Session } from "next-auth";
|
||||
import superjson from "superjson";
|
||||
import { ZodError } from "zod";
|
||||
import { getServerAuthSession } from "~/server/auth";
|
||||
import { prisma } from "~/server/db";
|
||||
import { capturePath } from "~/utils/analytics/serverAnalytics";
|
||||
|
||||
/**
|
||||
* 1. CONTEXT
|
||||
*
|
||||
* This section defines the "contexts" that are available in the backend API.
|
||||
*
|
||||
* These allow you to access things when processing a request, like the database, the session, etc.
|
||||
*/
|
||||
|
||||
type CreateContextOptions = {
|
||||
session: Session | null;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||
const noOp = () => {};
|
||||
|
||||
/**
|
||||
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
|
||||
* it from here.
|
||||
*
|
||||
* Examples of things you may need it for:
|
||||
* - testing, so we don't have to mock Next.js' req/res
|
||||
* - tRPC's `createSSGHelpers`, where we don't have req/res
|
||||
*
|
||||
* @see https://create.t3.gg/en/usage/trpc#-serverapitrpcts
|
||||
*/
|
||||
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
|
||||
return {
|
||||
session: opts.session,
|
||||
prisma,
|
||||
markAccessControlRun: noOp,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* This is the actual context you will use in your router. It will be used to process every request
|
||||
* that goes through your tRPC endpoint.
|
||||
*
|
||||
* @see https://trpc.io/docs/context
|
||||
*/
|
||||
export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
||||
const { req, res } = opts;
|
||||
|
||||
// Get the session from the server using the getServerSession wrapper function
|
||||
const session = await getServerAuthSession({ req, res });
|
||||
|
||||
return createInnerTRPCContext({
|
||||
session,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 2. INITIALIZATION
|
||||
*
|
||||
* This is where the tRPC API is initialized, connecting the context and transformer. We also parse
|
||||
* ZodErrors so that you get typesafety on the frontend if your procedure fails due to validation
|
||||
* errors on the backend.
|
||||
*/
|
||||
|
||||
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
|
||||
|
||||
const t = initTRPC.context<typeof createTRPCContext>().create({
|
||||
transformer: superjson,
|
||||
errorFormatter({ shape, error }) {
|
||||
return {
|
||||
...shape,
|
||||
data: {
|
||||
...shape.data,
|
||||
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* 3. ROUTER & PROCEDURE (THE IMPORTANT BIT)
|
||||
*
|
||||
* These are the pieces you use to build your tRPC API. You should import these a lot in the
|
||||
* "/src/server/api/routers" directory.
|
||||
*/
|
||||
|
||||
/**
|
||||
* This is how you create new routers and sub-routers in your tRPC API.
|
||||
*
|
||||
* @see https://trpc.io/docs/router
|
||||
*/
|
||||
export const createTRPCRouter = t.router;
|
||||
|
||||
/**
|
||||
* Public (unauthenticated) procedure
|
||||
*
|
||||
* This is the base piece you use to build new queries and mutations on your tRPC API. It does not
|
||||
* guarantee that a user querying is authorized, but you can still access user session data if they
|
||||
* are logged in.
|
||||
*/
|
||||
export const publicProcedure = t.procedure;
|
||||
|
||||
/** Reusable middleware that enforces users are logged in before running the procedure. */
|
||||
const enforceUserIsAuthed = t.middleware(async ({ ctx, next, path }) => {
|
||||
if (!ctx.session || !ctx.session.user) {
|
||||
throw new TRPCError({ code: "UNAUTHORIZED" });
|
||||
}
|
||||
|
||||
let accessControlRun = false;
|
||||
const resp = await next({
|
||||
ctx: {
|
||||
// infers the `session` as non-nullable
|
||||
session: { ...ctx.session, user: ctx.session.user },
|
||||
markAccessControlRun: () => {
|
||||
accessControlRun = true;
|
||||
},
|
||||
},
|
||||
});
|
||||
if (!accessControlRun)
|
||||
throw new TRPCError({
|
||||
code: "INTERNAL_SERVER_ERROR",
|
||||
message:
|
||||
"Protected routes must perform access control checks then explicitly invoke the `ctx.markAccessControlRun()` function to ensure we don't forget access control on a route.",
|
||||
});
|
||||
|
||||
capturePath(ctx.session, path);
|
||||
|
||||
return resp;
|
||||
});
|
||||
|
||||
/**
|
||||
* Protected (authenticated) procedure
|
||||
*
|
||||
* If you want a query or mutation to ONLY be accessible to logged in users, use this. It verifies
|
||||
* the session is valid and guarantees `ctx.session.user` is not null.
|
||||
*
|
||||
* @see https://trpc.io/docs/procedures
|
||||
*/
|
||||
export const protectedProcedure = t.procedure.use(enforceUserIsAuthed);
|
||||
67
app/src/server/auth.ts
Normal file
67
app/src/server/auth.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
||||
import { type GetServerSidePropsContext } from "next";
|
||||
import { getServerSession, type NextAuthOptions, type DefaultSession } from "next-auth";
|
||||
import { prisma } from "~/server/db";
|
||||
import GitHubProvider from "next-auth/providers/github";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
/**
|
||||
* Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
|
||||
* object and keep type safety.
|
||||
*
|
||||
* @see https://next-auth.js.org/getting-started/typescript#module-augmentation
|
||||
*/
|
||||
declare module "next-auth" {
|
||||
interface Session extends DefaultSession {
|
||||
user: {
|
||||
id: string;
|
||||
// ...other properties
|
||||
// role: UserRole;
|
||||
} & DefaultSession["user"];
|
||||
}
|
||||
|
||||
// interface User {
|
||||
// // ...other properties
|
||||
// // role: UserRole;
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for NextAuth.js used to configure adapters, providers, callbacks, etc.
|
||||
*
|
||||
* @see https://next-auth.js.org/configuration/options
|
||||
*/
|
||||
export const authOptions: NextAuthOptions = {
|
||||
callbacks: {
|
||||
session: ({ session, user }) => ({
|
||||
...session,
|
||||
user: {
|
||||
...session.user,
|
||||
id: user.id,
|
||||
},
|
||||
}),
|
||||
},
|
||||
adapter: PrismaAdapter(prisma),
|
||||
providers: [
|
||||
GitHubProvider({
|
||||
clientId: env.GITHUB_CLIENT_ID,
|
||||
clientSecret: env.GITHUB_CLIENT_SECRET,
|
||||
}),
|
||||
],
|
||||
theme: {
|
||||
logo: "/logo.svg",
|
||||
brandColor: "#ff5733",
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Wrapper for `getServerSession` so that you don't need to import the `authOptions` in every file.
|
||||
*
|
||||
* @see https://next-auth.js.org/configuration/nextjs
|
||||
*/
|
||||
export const getServerAuthSession = (ctx: {
|
||||
req: GetServerSidePropsContext["req"];
|
||||
res: GetServerSidePropsContext["res"];
|
||||
}) => {
|
||||
return getServerSession(ctx.req, ctx.res, authOptions);
|
||||
};
|
||||
17
app/src/server/db.ts
Normal file
17
app/src/server/db.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { PrismaClient } from "@prisma/client";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
const globalForPrisma = globalThis as unknown as {
|
||||
prisma: PrismaClient | undefined;
|
||||
};
|
||||
|
||||
export const prisma =
|
||||
globalForPrisma.prisma ??
|
||||
new PrismaClient({
|
||||
log:
|
||||
env.NODE_ENV === "development" && !env.RESTRICT_PRISMA_LOGS
|
||||
? ["query", "error", "warn"]
|
||||
: ["error"],
|
||||
});
|
||||
|
||||
if (env.NODE_ENV !== "production") globalForPrisma.prisma = prisma;
|
||||
19
app/src/server/scripts/openai-test.ts
Normal file
19
app/src/server/scripts/openai-test.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import "dotenv/config";
|
||||
import { openai } from "../utils/openai";
|
||||
|
||||
const resp = await openai.chat.completions.create({
|
||||
model: "gpt-3.5-turbo-0613",
|
||||
stream: true,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "count to 20",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
for await (const part of resp) {
|
||||
console.log("part", part);
|
||||
}
|
||||
|
||||
console.log("final resp", resp);
|
||||
26
app/src/server/scripts/replicate-test.ts
Normal file
26
app/src/server/scripts/replicate-test.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import "dotenv/config";
|
||||
import Replicate from "replicate";
|
||||
|
||||
const replicate = new Replicate({
|
||||
auth: process.env.REPLICATE_API_TOKEN || "",
|
||||
});
|
||||
|
||||
console.log("going to run");
|
||||
const prediction = await replicate.predictions.create({
|
||||
version: "3725a659b5afff1a0ba9bead5fac3899d998feaad00e07032ca2b0e35eb14f8a",
|
||||
input: {
|
||||
prompt: "...",
|
||||
},
|
||||
});
|
||||
|
||||
console.log("waiting");
|
||||
setInterval(() => {
|
||||
replicate.predictions.get(prediction.id).then((prediction) => {
|
||||
console.log(prediction);
|
||||
});
|
||||
}, 500);
|
||||
// const output = await replicate.wait(prediction, {});
|
||||
|
||||
// console.log(output);
|
||||
12
app/src/server/scripts/studio-prod.sh
Executable file
12
app/src/server/scripts/studio-prod.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#! /bin/bash
|
||||
|
||||
set -e
|
||||
cd "$(dirname "$0")/../../.."
|
||||
|
||||
|
||||
set -o allexport
|
||||
source .env
|
||||
set +o allexport
|
||||
|
||||
echo "Connecting to prod db"
|
||||
DATABASE_URL=$PROD_DATABASE_URL pnpm prisma studio
|
||||
31
app/src/server/tasks/defineTask.ts
Normal file
31
app/src/server/tasks/defineTask.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
// Import necessary dependencies
|
||||
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
// Define the defineTask function
|
||||
function defineTask<TPayload>(
|
||||
taskIdentifier: string,
|
||||
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
|
||||
) {
|
||||
const enqueue = async (payload: TPayload, runAt?: Date) => {
|
||||
console.log("Enqueuing task", taskIdentifier, payload);
|
||||
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt });
|
||||
};
|
||||
|
||||
const handler = (payload: TPayload, helpers: Helpers) => {
|
||||
helpers.logger.info(`Running task ${taskIdentifier} with payload: ${JSON.stringify(payload)}`);
|
||||
return taskHandler(payload, helpers);
|
||||
};
|
||||
|
||||
const task = {
|
||||
identifier: taskIdentifier,
|
||||
handler: handler as Task,
|
||||
};
|
||||
|
||||
return {
|
||||
enqueue,
|
||||
task,
|
||||
};
|
||||
}
|
||||
|
||||
export default defineTask;
|
||||
188
app/src/server/tasks/queryModel.task.ts
Normal file
188
app/src/server/tasks/queryModel.task.ts
Normal file
@@ -0,0 +1,188 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { prisma } from "~/server/db";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import hashPrompt from "../utils/hashPrompt";
|
||||
import defineTask from "./defineTask";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
|
||||
export type QueryModelJob = {
|
||||
cellId: string;
|
||||
stream: boolean;
|
||||
numPreviousTries: number;
|
||||
};
|
||||
|
||||
const MAX_AUTO_RETRIES = 50;
|
||||
const MIN_DELAY = 500; // milliseconds
|
||||
const MAX_DELAY = 15000; // milliseconds
|
||||
|
||||
function calculateDelay(numPreviousTries: number): number {
|
||||
const baseDelay = Math.min(MAX_DELAY, MIN_DELAY * Math.pow(2, numPreviousTries));
|
||||
const jitter = Math.random() * baseDelay;
|
||||
return baseDelay + jitter;
|
||||
}
|
||||
|
||||
export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) => {
|
||||
console.log("RUNNING TASK", task);
|
||||
const { cellId, stream, numPreviousTries } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: cellId },
|
||||
include: { modelResponses: true },
|
||||
});
|
||||
if (!cell) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If cell is not pending, then some other job is already processing it
|
||||
if (cell.retrievalStatus !== "PENDING") {
|
||||
return;
|
||||
}
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "IN_PROGRESS",
|
||||
jobStartedAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: { id: cell.promptVariantId },
|
||||
});
|
||||
if (!variant) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
errorMessage: "Prompt Variant not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: { id: cell.testScenarioId },
|
||||
});
|
||||
if (!scenario) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
errorMessage: "Scenario not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const prompt = await parsePromptConstructor(
|
||||
variant.promptConstructor,
|
||||
scenario.variableValues as JsonObject,
|
||||
);
|
||||
|
||||
if ("error" in prompt) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
errorMessage: prompt.error,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const provider = modelProviders[prompt.modelProvider];
|
||||
|
||||
const onStream = stream
|
||||
? (partialOutput: (typeof provider)["_outputSchema"]) => {
|
||||
wsConnection.emit("message", { channel: cell.id, payload: partialOutput });
|
||||
}
|
||||
: null;
|
||||
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
let modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
requestedAt: new Date(),
|
||||
},
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
modelResponse = await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
statusCode: response.statusCode,
|
||||
receivedAt: new Date(),
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
cost: response.cost,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && numPreviousTries < MAX_AUTO_RETRIES;
|
||||
const delay = calculateDelay(numPreviousTries);
|
||||
const retryTime = new Date(Date.now() + delay);
|
||||
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
errorMessage: response.message,
|
||||
receivedAt: new Date(),
|
||||
retryTime: shouldRetry ? retryTime : null,
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRetry) {
|
||||
await queryModel.enqueue(
|
||||
{
|
||||
cellId,
|
||||
stream,
|
||||
numPreviousTries: numPreviousTries + 1,
|
||||
},
|
||||
retryTime,
|
||||
);
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
},
|
||||
});
|
||||
} else {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||
await Promise.all([
|
||||
prisma.scenarioVariantCell.update({
|
||||
where: {
|
||||
id: cellId,
|
||||
},
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
jobQueuedAt: new Date(),
|
||||
},
|
||||
}),
|
||||
queryModel.enqueue({ cellId, stream, numPreviousTries: 0 }),
|
||||
]);
|
||||
};
|
||||
17
app/src/server/tasks/runNewEval.task.ts
Normal file
17
app/src/server/tasks/runNewEval.task.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { runAllEvals } from "../utils/evaluations";
|
||||
import defineTask from "./defineTask";
|
||||
|
||||
export type RunNewEvalJob = {
|
||||
experimentId: string;
|
||||
};
|
||||
|
||||
// When a new eval is created, we want to run it on all existing outputs, but return the new eval first
|
||||
export const runNewEval = defineTask<RunNewEvalJob>("runNewEval", async (task) => {
|
||||
console.log("RUNNING TASK", task);
|
||||
const { experimentId } = task;
|
||||
await runAllEvals(experimentId);
|
||||
});
|
||||
|
||||
export const queueRunNewEval = async (experimentId: string) => {
|
||||
await runNewEval.enqueue({ experimentId });
|
||||
};
|
||||
29
app/src/server/tasks/worker.ts
Normal file
29
app/src/server/tasks/worker.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { type TaskList, run } from "graphile-worker";
|
||||
import "dotenv/config";
|
||||
|
||||
import { env } from "~/env.mjs";
|
||||
import { queryModel } from "./queryModel.task";
|
||||
import { runNewEval } from "./runNewEval.task";
|
||||
|
||||
console.log("Starting worker");
|
||||
|
||||
const registeredTasks = [queryModel, runNewEval];
|
||||
|
||||
const taskList = registeredTasks.reduce((acc, task) => {
|
||||
acc[task.task.identifier] = task.task.handler;
|
||||
return acc;
|
||||
}, {} as TaskList);
|
||||
|
||||
// Run a worker to execute jobs:
|
||||
const runner = await run({
|
||||
connectionString: env.DATABASE_URL,
|
||||
concurrency: 50,
|
||||
// Install signal handlers for graceful shutdown on SIGINT, SIGTERM, etc
|
||||
noHandleSignals: false,
|
||||
pollInterval: 1000,
|
||||
taskList,
|
||||
});
|
||||
|
||||
console.log("Worker successfully started");
|
||||
|
||||
await runner.promise;
|
||||
143
app/src/server/utils/deriveNewContructFn.ts
Normal file
143
app/src/server/utils/deriveNewContructFn.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import { type PromptVariant } from "@prisma/client";
|
||||
import ivm from "isolated-vm";
|
||||
import dedent from "dedent";
|
||||
import { openai } from "./openai";
|
||||
import { isObject } from "lodash-es";
|
||||
import type { CreateChatCompletionRequestMessage } from "openai/resources/chat/completions";
|
||||
import formatPromptConstructor from "~/promptConstructor/format";
|
||||
import { type SupportedProvider, type Model } from "~/modelProviders/types";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
|
||||
const isolate = new ivm.Isolate({ memoryLimit: 128 });
|
||||
|
||||
export async function deriveNewConstructFn(
|
||||
originalVariant: PromptVariant | null,
|
||||
newModel?: Model,
|
||||
instructions?: string,
|
||||
) {
|
||||
if (originalVariant && !newModel && !instructions) {
|
||||
return originalVariant.promptConstructor;
|
||||
}
|
||||
if (originalVariant && (newModel || instructions)) {
|
||||
return await requestUpdatedPromptFunction(originalVariant, newModel, instructions);
|
||||
}
|
||||
return dedent`
|
||||
prompt = {
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: "Return 'Hello, world!'",
|
||||
}
|
||||
]
|
||||
}`;
|
||||
}
|
||||
|
||||
const NUM_RETRIES = 5;
|
||||
const requestUpdatedPromptFunction = async (
|
||||
originalVariant: PromptVariant,
|
||||
newModel?: Model,
|
||||
instructions?: string,
|
||||
) => {
|
||||
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
|
||||
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
|
||||
let newContructionFn = "";
|
||||
for (let i = 0; i < NUM_RETRIES; i++) {
|
||||
try {
|
||||
const messages: CreateChatCompletionRequestMessage[] = [
|
||||
{
|
||||
role: "system",
|
||||
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
|
||||
originalModelProvider.inputSchema,
|
||||
null,
|
||||
2,
|
||||
)}\n\nDo not add any assistant messages.`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `This is the current prompt constructor function:\n---\n${originalVariant.promptConstructor}`,
|
||||
},
|
||||
];
|
||||
if (newModel) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
|
||||
});
|
||||
if (newModel.provider !== originalModel.provider) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `As seen in the first argument to definePrompt, the old provider endpoint was "${
|
||||
originalModel.provider
|
||||
}". The new provider endpoint is "${
|
||||
newModel.provider
|
||||
}". Here is the schema for the new model:\n---\n${JSON.stringify(
|
||||
modelProviders[newModel.provider].inputSchema,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
});
|
||||
} else {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: `The provider is the same as the old provider: ${originalModel.provider}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (instructions) {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: instructions,
|
||||
});
|
||||
}
|
||||
const completion = await openai.chat.completions.create({
|
||||
model: "gpt-4",
|
||||
messages,
|
||||
functions: [
|
||||
{
|
||||
name: "update_prompt_constructor_function",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
new_prompt_function: {
|
||||
type: "string",
|
||||
description: "The new prompt function, runnable in typescript",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "update_prompt_constructor_function",
|
||||
},
|
||||
});
|
||||
const argString = completion.choices[0]?.message?.function_call?.arguments || "{}";
|
||||
|
||||
const code = `
|
||||
global.contructPromptFunctionArgs = ${argString};
|
||||
`;
|
||||
|
||||
const context = await isolate.createContext();
|
||||
|
||||
const jail = context.global;
|
||||
await jail.set("global", jail.derefInto());
|
||||
|
||||
const script = await isolate.compileScript(code);
|
||||
|
||||
await script.run(context);
|
||||
const contructPromptFunctionArgs = (await context.global.get(
|
||||
"contructPromptFunctionArgs",
|
||||
)) as ivm.Reference;
|
||||
|
||||
const args = await contructPromptFunctionArgs.copy(); // Get the actual value from the isolate
|
||||
|
||||
if (args && isObject(args) && "new_prompt_function" in args) {
|
||||
newContructionFn = await formatPromptConstructor(args.new_prompt_function as string);
|
||||
break;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
return newContructionFn;
|
||||
};
|
||||
6
app/src/server/utils/error.ts
Normal file
6
app/src/server/utils/error.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export default function userError(message: string): { status: "error"; message: string } {
|
||||
return {
|
||||
status: "error",
|
||||
message,
|
||||
};
|
||||
}
|
||||
99
app/src/server/utils/evaluations.ts
Normal file
99
app/src/server/utils/evaluations.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
const runAndSaveEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelResponse, provider);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
modelResponseId_evaluationId: {
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
...result,
|
||||
},
|
||||
update: {
|
||||
...result,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(
|
||||
async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
// Will not run eval-output pairs that already exist in the database
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelResponse.findMany({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: {
|
||||
include: {
|
||||
testScenario: true,
|
||||
promptVariant: true,
|
||||
},
|
||||
},
|
||||
outputEvaluations: true,
|
||||
},
|
||||
});
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const evalsToBeRun = evals.filter(
|
||||
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
evalsToBeRun.map(async (evaluation) => {
|
||||
await runAndSaveEval(
|
||||
evaluation,
|
||||
output.scenarioVariantCell.testScenario,
|
||||
output,
|
||||
output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
|
||||
);
|
||||
}),
|
||||
);
|
||||
}),
|
||||
);
|
||||
};
|
||||
15
app/src/server/utils/fillTemplate.ts
Normal file
15
app/src/server/utils/fillTemplate.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
export type VariableMap = Record<string, string>;
|
||||
|
||||
// Escape quotes to match the way we encode JSON
|
||||
export function escapeQuotes(str: string) {
|
||||
return str.replace(/(\\")|"/g, (match, p1) => (p1 ? match : '\\"'));
|
||||
}
|
||||
|
||||
// Escape regex special characters
|
||||
export function escapeRegExp(str: string) {
|
||||
return str.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
|
||||
}
|
||||
|
||||
export function fillTemplate(template: string, variables: VariableMap): string {
|
||||
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
||||
}
|
||||
126
app/src/server/utils/generateNewCell.ts
Normal file
126
app/src/server/utils/generateNewCell.ts
Normal file
@@ -0,0 +1,126 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
import { omit } from "lodash-es";
|
||||
import { queueQueryModel } from "../tasks/queryModel.task";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
|
||||
export const generateNewCell = async (
|
||||
variantId: string,
|
||||
scenarioId: string,
|
||||
options?: { stream?: boolean },
|
||||
): Promise<void> => {
|
||||
const stream = options?.stream ?? false;
|
||||
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: variantId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: scenarioId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return;
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (cell) return;
|
||||
|
||||
const parsedConstructFn = await parsePromptConstructor(
|
||||
variant.promptConstructor,
|
||||
scenario.variableValues as JsonObject,
|
||||
);
|
||||
|
||||
if ("error" in parsedConstructFn) {
|
||||
await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
prompt: parsedConstructFn.modelInput as unknown as Prisma.InputJsonValue,
|
||||
retrievalStatus: "PENDING",
|
||||
},
|
||||
include: {
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
const matchingModelResponse = await prisma.modelResponse.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
receivedAt: "desc",
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: true,
|
||||
},
|
||||
take: 1,
|
||||
});
|
||||
|
||||
if (matchingModelResponse) {
|
||||
const newModelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
|
||||
scenarioVariantCellId: cell.id,
|
||||
output: matchingModelResponse.output as Prisma.InputJsonValue,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
|
||||
jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
|
||||
},
|
||||
});
|
||||
|
||||
// Copy over all eval results as well
|
||||
await Promise.all(
|
||||
(
|
||||
await prisma.outputEvaluation.findMany({
|
||||
where: { modelResponseId: matchingModelResponse.id },
|
||||
})
|
||||
).map(async (evaluation) => {
|
||||
await prisma.outputEvaluation.create({
|
||||
data: {
|
||||
...omit(evaluation, ["id"]),
|
||||
modelResponseId: newModelResponse.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
await queueQueryModel(cell.id, stream);
|
||||
}
|
||||
};
|
||||
37
app/src/server/utils/hashPrompt.ts
Normal file
37
app/src/server/utils/hashPrompt.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import crypto from "crypto";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { ParsedPromptConstructor } from "~/promptConstructor/parse";
|
||||
|
||||
function sortKeys(obj: JsonValue): JsonValue {
|
||||
if (typeof obj !== "object" || obj === null) {
|
||||
// Not an object or array, return as is
|
||||
return obj;
|
||||
}
|
||||
|
||||
if (Array.isArray(obj)) {
|
||||
return obj.map(sortKeys);
|
||||
}
|
||||
|
||||
// Get keys and sort them
|
||||
const keys = Object.keys(obj).sort();
|
||||
const sortedObj = {};
|
||||
|
||||
for (const key of keys) {
|
||||
// @ts-expect-error not worth fixing types
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||
sortedObj[key] = sortKeys(obj[key]);
|
||||
}
|
||||
|
||||
return sortedObj;
|
||||
}
|
||||
|
||||
export default function hashPrompt(prompt: ParsedPromptConstructor<any>): string {
|
||||
// Sort object keys recursively
|
||||
const sortedObj = sortKeys(prompt as unknown as JsonValue);
|
||||
|
||||
// Convert to JSON and hash it
|
||||
const str = JSON.stringify(sortedObj);
|
||||
const hash = crypto.createHash("sha256");
|
||||
hash.update(str);
|
||||
return hash.digest("hex");
|
||||
}
|
||||
6
app/src/server/utils/openai.ts
Normal file
6
app/src/server/utils/openai.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
import OpenAI from "openai";
|
||||
|
||||
// Set a dummy key so it doesn't fail at build time
|
||||
export const openai = new OpenAI({ apiKey: env.OPENAI_API_KEY ?? "dummy-key" });
|
||||
12
app/src/server/utils/recordExperimentUpdated.ts
Normal file
12
app/src/server/utils/recordExperimentUpdated.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
export const recordExperimentUpdated = (experimentId: string) => {
|
||||
return prisma.experiment.update({
|
||||
where: {
|
||||
id: experimentId,
|
||||
},
|
||||
data: {
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
});
|
||||
};
|
||||
65
app/src/server/utils/reorderPromptVariants.ts
Normal file
65
app/src/server/utils/reorderPromptVariants.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
export const reorderPromptVariants = async (
|
||||
movedId: string,
|
||||
stationaryTargetId: string,
|
||||
alwaysInsertRight?: boolean,
|
||||
) => {
|
||||
const moved = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: movedId,
|
||||
},
|
||||
});
|
||||
|
||||
const target = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: stationaryTargetId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!moved || !target || moved.experimentId !== target.experimentId) {
|
||||
throw new Error(`Prompt Variant with id ${movedId} or ${stationaryTargetId} does not exist`);
|
||||
}
|
||||
|
||||
const visibleItems = await prisma.promptVariant.findMany({
|
||||
where: {
|
||||
experimentId: moved.experimentId,
|
||||
visible: true,
|
||||
},
|
||||
orderBy: {
|
||||
sortIndex: "asc",
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the moved item from its current position
|
||||
const orderedItems = visibleItems.filter((item) => item.id !== moved.id);
|
||||
|
||||
// Find the index of the moved item and the target item
|
||||
const movedIndex = visibleItems.findIndex((item) => item.id === moved.id);
|
||||
const targetIndex = visibleItems.findIndex((item) => item.id === target.id);
|
||||
|
||||
// Determine the new index for the moved item
|
||||
let newIndex;
|
||||
if (movedIndex < targetIndex || alwaysInsertRight) {
|
||||
newIndex = targetIndex + 1; // Insert after the target item
|
||||
} else {
|
||||
newIndex = targetIndex; // Insert before the target item
|
||||
}
|
||||
|
||||
// Insert the moved item at the new position
|
||||
orderedItems.splice(newIndex, 0, moved);
|
||||
|
||||
// Now, we need to update all the items with their new sortIndex
|
||||
await prisma.$transaction(
|
||||
orderedItems.map((item, index) => {
|
||||
return prisma.promptVariant.update({
|
||||
where: {
|
||||
id: item.id,
|
||||
},
|
||||
data: {
|
||||
sortIndex: index,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
93
app/src/server/utils/runOneEval.ts
Normal file
93
app/src/server/utils/runOneEval.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
import dedent from "dedent";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
export const runGpt4Eval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
stringifiedOutput: string,
|
||||
): Promise<{ result: number; details: string }> => {
|
||||
const output = await openai.chat.completions.create({
|
||||
model: "gpt-4-0613",
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: dedent`
|
||||
You are a highly intelligent AI model and have been tasked with evaluating the quality of a simpler model. Your objective is to determine whether the simpler model has produced a successful and correct output. You should return "true" if the output was successful and "false" if it was not. Pay more attention to the semantics of the output than the formatting. Success is defined in the following terms:
|
||||
---
|
||||
${evaluation.value}
|
||||
`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `Scenario:\n---\n${JSON.stringify(scenario.variableValues, null, 2)}`,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
name: "report_success",
|
||||
},
|
||||
functions: [
|
||||
{
|
||||
name: "report_success",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["thoughts", "success"],
|
||||
properties: {
|
||||
thoughts: {
|
||||
type: "string",
|
||||
description: "Explain your reasoning for considering this a pass or fail",
|
||||
},
|
||||
success: {
|
||||
type: "boolean",
|
||||
description:
|
||||
"Whether the simpler model successfully completed the task for this scenario",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
try {
|
||||
const out = JSON.parse(output.choices[0]?.message?.function_call?.arguments ?? "");
|
||||
return { result: out.success ? 1 : 0, details: out.thoughts ?? JSON.stringify(out) };
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return { result: 0, details: "Error parsing GPT-4 output" };
|
||||
}
|
||||
};
|
||||
|
||||
export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const modelProvider = modelProviders[provider];
|
||||
const message = modelProvider.normalizeOutput(modelResponse.output);
|
||||
|
||||
if (!message) return { result: 0 };
|
||||
|
||||
const stringifiedOutput =
|
||||
message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;
|
||||
|
||||
const matchRegex = escapeRegExp(
|
||||
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
||||
);
|
||||
|
||||
switch (evaluation.evalType) {
|
||||
case "CONTAINS":
|
||||
return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
|
||||
case "DOES_NOT_CONTAIN":
|
||||
return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
|
||||
case "GPT4_EVAL":
|
||||
return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
|
||||
}
|
||||
};
|
||||
1
app/src/server/utils/sleep.ts
Normal file
1
app/src/server/utils/sleep.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
19
app/src/server/utils/userOrg.ts
Normal file
19
app/src/server/utils/userOrg.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { prisma } from "~/server/db";
|
||||
|
||||
export default async function userOrg(userId: string) {
|
||||
return await prisma.organization.upsert({
|
||||
where: {
|
||||
personalOrgUserId: userId,
|
||||
},
|
||||
update: {},
|
||||
create: {
|
||||
personalOrgUserId: userId,
|
||||
organizationUsers: {
|
||||
create: {
|
||||
userId: userId,
|
||||
role: "ADMIN",
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user