Assign experiments and datasets to correct org

This commit is contained in:
David Corbitt
2023-08-07 19:10:27 -07:00
parent 9f17d98736
commit 0f9a83cf45
7 changed files with 280 additions and 268 deletions

View File

@@ -15,6 +15,7 @@ import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs"; import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
type DatasetData = { type DatasetData = {
name: string; name: string;
@@ -71,11 +72,12 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewDatasetCard = () => { export const NewDatasetCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId);
const createMutation = api.datasets.create.useMutation(); const createMutation = api.datasets.create.useMutation();
const [createDataset, isLoading] = useHandledAsyncCallback(async () => { const [createDataset, isLoading] = useHandledAsyncCallback(async () => {
const newDataset = await createMutation.mutateAsync({ label: "New Dataset" }); const newDataset = await createMutation.mutateAsync({ organizationId: selectedOrgId ?? "" });
await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } }); await router.push({ pathname: "/data/[id]", query: { id: newDataset.id } });
}, [createMutation, router]); }, [createMutation, router, selectedOrgId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -15,6 +15,7 @@ import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs"; import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks"; import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
type ExperimentData = { type ExperimentData = {
testScenarioCount: number; testScenarioCount: number;
@@ -75,11 +76,17 @@ const CountLabel = ({ label, count }: { label: string; count: number }) => {
export const NewExperimentCard = () => { export const NewExperimentCard = () => {
const router = useRouter(); const router = useRouter();
const selectedOrgId = useAppStore((s) => s.selectedOrgId);
const createMutation = api.experiments.create.useMutation(); const createMutation = api.experiments.create.useMutation();
const [createExperiment, isLoading] = useHandledAsyncCallback(async () => { const [createExperiment, isLoading] = useHandledAsyncCallback(async () => {
const newExperiment = await createMutation.mutateAsync({ label: "New Experiment" }); const newExperiment = await createMutation.mutateAsync({
await router.push({ pathname: "/experiments/[id]", query: { id: newExperiment.id } }); organizationId: selectedOrgId ?? "",
}, [createMutation, router]); });
await router.push({
pathname: "/experiments/[id]",
query: { id: newExperiment.id },
});
}, [createMutation, router, selectedOrgId]);
return ( return (
<AspectRatio ratio={1.2} w="full"> <AspectRatio ratio={1.2} w="full">

View File

@@ -1,14 +1,14 @@
import { HStack, Flex, Text } from "@chakra-ui/react"; import { HStack, Flex, Text } from "@chakra-ui/react";
import Link from "next/link"; import Link from "next/link";
import { useSelectedOrg } from "~/utils/hooks"; import { useSelectedOrg } from "~/utils/hooks";
// Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't // Have to export only contents here instead of full BreadcrumbItem because Chakra doesn't
// recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb. // recognize a BreadcrumbItem exported with this component as a valid child of Breadcrumb.
export default function ProjectBreadcrumbContents({ orgName = "" }: { orgName?: string }) {
export default function ProjectBreadcrumbContents() {
const { data: selectedOrg } = useSelectedOrg(); const { data: selectedOrg } = useSelectedOrg();
orgName = orgName || selectedOrg?.name || "";
return ( return (
<Link href="/home"> <Link href="/home">
<HStack w="full"> <HStack w="full">
@@ -23,10 +23,10 @@ export default function ProjectBreadcrumbContents() {
alignItems="center" alignItems="center"
justifyContent="center" justifyContent="center"
> >
<Text>{selectedOrg?.name[0]?.toUpperCase()}</Text> <Text>{orgName[0]?.toUpperCase()}</Text>
</Flex> </Flex>
<Text display={{ base: "none", md: "block" }} py={1}> <Text display={{ base: "none", md: "block" }} py={1}>
{selectedOrg?.name} {orgName}
</Text> </Text>
</HStack> </HStack>
</Link> </Link>

View File

@@ -60,7 +60,7 @@ export default function Dataset() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents /> <ProjectBreadcrumbContents orgName={dataset.data?.organization?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/data"> <Link href="/data">

View File

@@ -110,7 +110,7 @@ export default function Experiment() {
<PageHeaderContainer> <PageHeaderContainer>
<Breadcrumb> <Breadcrumb>
<BreadcrumbItem> <BreadcrumbItem>
<ProjectBreadcrumbContents /> <ProjectBreadcrumbContents orgName={experiment.data?.organization?.name} />
</BreadcrumbItem> </BreadcrumbItem>
<BreadcrumbItem> <BreadcrumbItem>
<Link href="/experiments"> <Link href="/experiments">

View File

@@ -3,11 +3,10 @@ import { createTRPCRouter, protectedProcedure, publicProcedure } from "~/server/
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { import {
requireCanModifyDataset, requireCanModifyDataset,
requireCanModifyOrganization,
requireCanViewDataset, requireCanViewDataset,
requireCanViewOrganization, requireCanViewOrganization,
requireNothing,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
export const datasetsRouter = createTRPCRouter({ export const datasetsRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
@@ -36,30 +35,30 @@ export const datasetsRouter = createTRPCRouter({
await requireCanViewDataset(input.id, ctx); await requireCanViewDataset(input.id, ctx);
return await prisma.dataset.findFirstOrThrow({ return await prisma.dataset.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: {
organization: true,
},
}); });
}), }),
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { create: protectedProcedure
// Anyone can create an experiment .input(z.object({ organizationId: z.string() }))
requireNothing(ctx); .mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx);
const numDatasets = await prisma.dataset.count({ const numDatasets = await prisma.dataset.count({
where: { where: {
organization: { organizationId: input.organizationId,
organizationUsers: {
some: { userId: ctx.session.user.id },
},
}, },
}, });
});
return await prisma.dataset.create({ return await prisma.dataset.create({
data: { data: {
name: `Dataset ${numDatasets + 1}`, name: `Dataset ${numDatasets + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id, organizationId: input.organizationId,
}, },
}); });
}), }),
update: protectedProcedure update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) })) .input(z.object({ id: z.string(), updates: z.object({ name: z.string() }) }))

View File

@@ -8,11 +8,10 @@ import { generateNewCell } from "~/server/utils/generateNewCell";
import { import {
canModifyExperiment, canModifyExperiment,
requireCanModifyExperiment, requireCanModifyExperiment,
requireCanModifyOrganization,
requireCanViewExperiment, requireCanViewExperiment,
requireCanViewOrganization, requireCanViewOrganization,
requireNothing,
} from "~/utils/accessControl"; } from "~/utils/accessControl";
import userOrg from "~/server/utils/userOrg";
import generateTypes from "~/modelProviders/generateTypes"; import generateTypes from "~/modelProviders/generateTypes";
import { promptConstructorVersion } from "~/promptConstructor/version"; import { promptConstructorVersion } from "~/promptConstructor/version";
@@ -47,7 +46,7 @@ export const experimentsRouter = createTRPCRouter({
list: protectedProcedure list: protectedProcedure
.input(z.object({ organizationId: z.string() })) .input(z.object({ organizationId: z.string() }))
.query(async ({ input, ctx }) => { .query(async ({ input, ctx }) => {
await requireCanViewOrganization(input.organizationId, ctx) await requireCanViewOrganization(input.organizationId, ctx);
const experiments = await prisma.experiment.findMany({ const experiments = await prisma.experiment.findMany({
where: { where: {
@@ -90,6 +89,9 @@ export const experimentsRouter = createTRPCRouter({
await requireCanViewExperiment(input.id, ctx); await requireCanViewExperiment(input.id, ctx);
const experiment = await prisma.experiment.findFirstOrThrow({ const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id }, where: { id: input.id },
include: {
organization: true,
},
}); });
const canModify = ctx.session?.user.id const canModify = ctx.session?.user.id
@@ -105,222 +107,224 @@ export const experimentsRouter = createTRPCRouter({
}; };
}), }),
fork: protectedProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => { fork: protectedProcedure
await requireCanViewExperiment(input.id, ctx); .input(z.object({ id: z.string(), organizationId: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
await requireCanModifyOrganization(input.organizationId, ctx);
const [ const [
existingExp, existingExp,
existingVariants, existingVariants,
existingScenarios, existingScenarios,
existingCells, existingCells,
evaluations, evaluations,
templateVariables, templateVariables,
] = await prisma.$transaction([ ] = await prisma.$transaction([
prisma.experiment.findUniqueOrThrow({ prisma.experiment.findUniqueOrThrow({
where: { where: {
id: input.id, id: input.id,
},
}),
prisma.promptVariant.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.testScenario.findMany({
where: {
experimentId: input.id,
visible: true,
},
}),
prisma.scenarioVariantCell.findMany({
where: {
testScenario: {
visible: true,
}, },
promptVariant: { }),
prisma.promptVariant.findMany({
where: {
experimentId: input.id, experimentId: input.id,
visible: true, visible: true,
}, },
}, }),
include: { prisma.testScenario.findMany({
modelResponses: { where: {
include: { experimentId: input.id,
outputEvaluations: true, visible: true,
},
}),
prisma.scenarioVariantCell.findMany({
where: {
testScenario: {
visible: true,
},
promptVariant: {
experimentId: input.id,
visible: true,
}, },
}, },
}, include: {
}), modelResponses: {
prisma.evaluation.findMany({ include: {
where: { outputEvaluations: true,
experimentId: input.id, },
}, },
}), },
prisma.templateVariable.findMany({ }),
where: { prisma.evaluation.findMany({
experimentId: input.id, where: {
}, experimentId: input.id,
}), },
]); }),
prisma.templateVariable.findMany({
where: {
experimentId: input.id,
},
}),
]);
const newExperimentId = uuidv4(); const newExperimentId = uuidv4();
const existingToNewVariantIds = new Map<string, string>(); const existingToNewVariantIds = new Map<string, string>();
const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = []; const variantsToCreate: Prisma.PromptVariantCreateManyInput[] = [];
for (const variant of existingVariants) { for (const variant of existingVariants) {
const newVariantId = uuidv4(); const newVariantId = uuidv4();
existingToNewVariantIds.set(variant.id, newVariantId); existingToNewVariantIds.set(variant.id, newVariantId);
variantsToCreate.push({ variantsToCreate.push({
...variant, ...variant,
id: newVariantId, id: newVariantId,
experimentId: newExperimentId, experimentId: newExperimentId,
});
}
const existingToNewScenarioIds = new Map<string, string>();
const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
for (const scenario of existingScenarios) {
const newScenarioId = uuidv4();
existingToNewScenarioIds.set(scenario.id, newScenarioId);
scenariosToCreate.push({
...scenario,
id: newScenarioId,
experimentId: newExperimentId,
variableValues: scenario.variableValues as Prisma.InputJsonValue,
});
}
const existingToNewEvaluationIds = new Map<string, string>();
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
for (const evaluation of evaluations) {
const newEvaluationId = uuidv4();
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
evaluationsToCreate.push({
...evaluation,
id: newEvaluationId,
experimentId: newExperimentId,
});
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
const { modelResponses, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
for (const modelResponse of modelResponses) {
const newModelResponseId = uuidv4();
const { outputEvaluations, ...modelResponseData } = modelResponse;
modelResponsesToCreate.push({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
}); });
for (const evaluation of outputEvaluations) { }
outputEvaluationsToCreate.push({
...evaluation, const existingToNewScenarioIds = new Map<string, string>();
id: uuidv4(), const scenariosToCreate: Prisma.TestScenarioCreateManyInput[] = [];
modelResponseId: newModelResponseId, for (const scenario of existingScenarios) {
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "", const newScenarioId = uuidv4();
existingToNewScenarioIds.set(scenario.id, newScenarioId);
scenariosToCreate.push({
...scenario,
id: newScenarioId,
experimentId: newExperimentId,
variableValues: scenario.variableValues as Prisma.InputJsonValue,
});
}
const existingToNewEvaluationIds = new Map<string, string>();
const evaluationsToCreate: Prisma.EvaluationCreateManyInput[] = [];
for (const evaluation of evaluations) {
const newEvaluationId = uuidv4();
existingToNewEvaluationIds.set(evaluation.id, newEvaluationId);
evaluationsToCreate.push({
...evaluation,
id: newEvaluationId,
experimentId: newExperimentId,
});
}
const cellsToCreate: Prisma.ScenarioVariantCellCreateManyInput[] = [];
const modelResponsesToCreate: Prisma.ModelResponseCreateManyInput[] = [];
const outputEvaluationsToCreate: Prisma.OutputEvaluationCreateManyInput[] = [];
for (const cell of existingCells) {
const newCellId = uuidv4();
const { modelResponses, ...cellData } = cell;
cellsToCreate.push({
...cellData,
id: newCellId,
promptVariantId: existingToNewVariantIds.get(cell.promptVariantId) ?? "",
testScenarioId: existingToNewScenarioIds.get(cell.testScenarioId) ?? "",
prompt: (cell.prompt as Prisma.InputJsonValue) ?? undefined,
});
for (const modelResponse of modelResponses) {
const newModelResponseId = uuidv4();
const { outputEvaluations, ...modelResponseData } = modelResponse;
modelResponsesToCreate.push({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
}); });
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({
...evaluation,
id: uuidv4(),
modelResponseId: newModelResponseId,
evaluationId: existingToNewEvaluationIds.get(evaluation.evaluationId) ?? "",
});
}
} }
} }
}
const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = []; const templateVariablesToCreate: Prisma.TemplateVariableCreateManyInput[] = [];
for (const templateVariable of templateVariables) { for (const templateVariable of templateVariables) {
templateVariablesToCreate.push({ templateVariablesToCreate.push({
...templateVariable, ...templateVariable,
id: uuidv4(), id: uuidv4(),
experimentId: newExperimentId, experimentId: newExperimentId,
}); });
} }
const maxSortIndex = const maxSortIndex =
( (
await prisma.experiment.aggregate({ await prisma.experiment.aggregate({
_max: { _max: {
sortIndex: true, sortIndex: true,
},
})
)._max?.sortIndex ?? 0;
await prisma.$transaction([
prisma.experiment.create({
data: {
id: newExperimentId,
sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`,
organizationId: input.organizationId,
}, },
}) }),
)._max?.sortIndex ?? 0; prisma.promptVariant.createMany({
data: variantsToCreate,
}),
prisma.testScenario.createMany({
data: scenariosToCreate,
}),
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
prisma.modelResponse.createMany({
data: modelResponsesToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,
}),
prisma.outputEvaluation.createMany({
data: outputEvaluationsToCreate,
}),
prisma.templateVariable.createMany({
data: templateVariablesToCreate,
}),
]);
await prisma.$transaction([ return newExperimentId;
prisma.experiment.create({ }),
create: protectedProcedure
.input(z.object({ organizationId: z.string() }))
.mutation(async ({ input, ctx }) => {
await requireCanModifyOrganization(input.organizationId, ctx);
const maxSortIndex =
(
await prisma.experiment.aggregate({
_max: {
sortIndex: true,
},
where: { organizationId: input.organizationId },
})
)._max?.sortIndex ?? 0;
const exp = await prisma.experiment.create({
data: { data: {
id: newExperimentId,
sortIndex: maxSortIndex + 1, sortIndex: maxSortIndex + 1,
label: `${existingExp.label} (forked)`, label: `Experiment ${maxSortIndex + 1}`,
organizationId: (await userOrg(ctx.session.user.id)).id, organizationId: input.organizationId,
}, },
}), });
prisma.promptVariant.createMany({
data: variantsToCreate,
}),
prisma.testScenario.createMany({
data: scenariosToCreate,
}),
prisma.scenarioVariantCell.createMany({
data: cellsToCreate,
}),
prisma.modelResponse.createMany({
data: modelResponsesToCreate,
}),
prisma.evaluation.createMany({
data: evaluationsToCreate,
}),
prisma.outputEvaluation.createMany({
data: outputEvaluationsToCreate,
}),
prisma.templateVariable.createMany({
data: templateVariablesToCreate,
}),
]);
return newExperimentId; const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
}), prisma.promptVariant.create({
data: {
create: protectedProcedure.input(z.object({})).mutation(async ({ ctx }) => { experimentId: exp.id,
// Anyone can create an experiment label: "Prompt Variant 1",
requireNothing(ctx); sortIndex: 0,
// The interpolated $ is necessary until dedent incorporates
const organizationId = (await userOrg(ctx.session.user.id)).id; // https://github.com/dmnd/dedent/pull/46
promptConstructor: dedent`
const maxSortIndex =
(
await prisma.experiment.aggregate({
_max: {
sortIndex: true,
},
where: { organizationId },
})
)._max?.sortIndex ?? 0;
const exp = await prisma.experiment.create({
data: {
sortIndex: maxSortIndex + 1,
label: `Experiment ${maxSortIndex + 1}`,
organizationId,
},
});
const [variant, _, scenario1, scenario2, scenario3] = await prisma.$transaction([
prisma.promptVariant.create({
data: {
experimentId: exp.id,
label: "Prompt Variant 1",
sortIndex: 0,
// The interpolated $ is necessary until dedent incorporates
// https://github.com/dmnd/dedent/pull/46
promptConstructor: dedent`
/** /**
* Use Javascript to define an OpenAI chat completion * Use Javascript to define an OpenAI chat completion
* (https://platform.openai.com/docs/api-reference/chat/create). * (https://platform.openai.com/docs/api-reference/chat/create).
@@ -339,49 +343,49 @@ export const experimentsRouter = createTRPCRouter({
}, },
], ],
});`, });`,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion", modelProvider: "openai/ChatCompletion",
promptConstructorVersion, promptConstructorVersion,
},
}),
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "language",
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "English",
}, },
}, }),
}), prisma.templateVariable.create({
prisma.testScenario.create({ data: {
data: { experimentId: exp.id,
experimentId: exp.id, label: "language",
variableValues: {
language: "Spanish",
}, },
}, }),
}), prisma.testScenario.create({
prisma.testScenario.create({ data: {
data: { experimentId: exp.id,
experimentId: exp.id, variableValues: {
variableValues: { language: "English",
language: "German", },
}, },
}, }),
}), prisma.testScenario.create({
]); data: {
experimentId: exp.id,
variableValues: {
language: "Spanish",
},
},
}),
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: {
language: "German",
},
},
}),
]);
await generateNewCell(variant.id, scenario1.id); await generateNewCell(variant.id, scenario1.id);
await generateNewCell(variant.id, scenario2.id); await generateNewCell(variant.id, scenario2.id);
await generateNewCell(variant.id, scenario3.id); await generateNewCell(variant.id, scenario3.id);
return exp; return exp;
}), }),
update: protectedProcedure update: protectedProcedure
.input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) })) .input(z.object({ id: z.string(), updates: z.object({ label: z.string() }) }))