cache output evals
This commit is contained in:
@@ -41,7 +41,6 @@ model PromptVariant {
|
|||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
scenarioVariantCells ScenarioVariantCell[]
|
scenarioVariantCells ScenarioVariantCell[]
|
||||||
EvaluationResult EvaluationResult[]
|
|
||||||
|
|
||||||
@@index([uiId])
|
@@index([uiId])
|
||||||
}
|
}
|
||||||
@@ -124,6 +123,7 @@ model ModelOutput {
|
|||||||
|
|
||||||
scenarioVariantCellId String @db.Uuid
|
scenarioVariantCellId String @db.Uuid
|
||||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||||
|
outputEvaluation OutputEvaluation[]
|
||||||
|
|
||||||
@@unique([scenarioVariantCellId])
|
@@unique([scenarioVariantCellId])
|
||||||
@@index([inputHash])
|
@@index([inputHash])
|
||||||
@@ -146,25 +146,26 @@ model Evaluation {
|
|||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
EvaluationResult EvaluationResult[]
|
OutputEvaluation OutputEvaluation[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model EvaluationResult {
|
model OutputEvaluation {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
passCount Int
|
// Number between 0 (fail) and 1 (pass)
|
||||||
failCount Int
|
result Float
|
||||||
|
details String?
|
||||||
|
|
||||||
|
modelOutputId String @db.Uuid
|
||||||
|
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
evaluationId String @db.Uuid
|
evaluationId String @db.Uuid
|
||||||
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
promptVariantId String @db.Uuid
|
|
||||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
|
||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
@@unique([evaluationId, promptVariantId])
|
@@unique([modelOutputId, evaluationId])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Necessary for Next auth
|
// Necessary for Next auth
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ export function EvaluationEditor(props: {
|
|||||||
<Input
|
<Input
|
||||||
size="sm"
|
size="sm"
|
||||||
value={values.label}
|
value={values.label}
|
||||||
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl flex={1}>
|
<FormControl flex={1}>
|
||||||
@@ -125,6 +125,7 @@ export default function EditEvaluations() {
|
|||||||
}
|
}
|
||||||
await utils.evaluations.list.invalidate();
|
await utils.evaluations.list.invalidate();
|
||||||
await utils.promptVariants.stats.invalidate();
|
await utils.promptVariants.stats.invalidate();
|
||||||
|
await utils.scenarioVariantCells.get.invalidate();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onCancel = useCallback(() => {
|
const onCancel = useCallback(() => {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
import { type ModelOutput } from "@prisma/client";
|
|
||||||
import { type SupportedModel } from "~/server/types";
|
import { type SupportedModel } from "~/server/types";
|
||||||
import { type Scenario } from "../types";
|
import { type Scenario } from "../types";
|
||||||
import { useExperiment } from "~/utils/hooks";
|
import { type RouterOutputs } from "~/utils/api";
|
||||||
import { api } from "~/utils/api";
|
|
||||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
|
||||||
import { HStack, Icon, Text } from "@chakra-ui/react";
|
import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||||
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
|
|||||||
export const OutputStats = ({
|
export const OutputStats = ({
|
||||||
model,
|
model,
|
||||||
modelOutput,
|
modelOutput,
|
||||||
scenario,
|
|
||||||
}: {
|
}: {
|
||||||
model: SupportedModel | string | null;
|
model: SupportedModel | string | null;
|
||||||
modelOutput: ModelOutput;
|
modelOutput: NonNullable<
|
||||||
|
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||||
|
>;
|
||||||
scenario: Scenario;
|
scenario: Scenario;
|
||||||
}) => {
|
}) => {
|
||||||
const timeToComplete = modelOutput.timeToComplete;
|
const timeToComplete = modelOutput.timeToComplete;
|
||||||
const experiment = useExperiment();
|
|
||||||
const evals =
|
|
||||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
|
||||||
|
|
||||||
const promptTokens = modelOutput.promptTokens;
|
const promptTokens = modelOutput.promptTokens;
|
||||||
const completionTokens = modelOutput.completionTokens;
|
const completionTokens = modelOutput.completionTokens;
|
||||||
@@ -38,11 +33,11 @@ export const OutputStats = ({
|
|||||||
return (
|
return (
|
||||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||||
<HStack flex={1}>
|
<HStack flex={1}>
|
||||||
{evals.map((evaluation) => {
|
{modelOutput.outputEvaluation.map((evaluation) => {
|
||||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
const passed = evaluation.result > 0.5;
|
||||||
return (
|
return (
|
||||||
<HStack spacing={0} key={evaluation.id}>
|
<HStack spacing={0} key={evaluation.id}>
|
||||||
<Text>{evaluation.label}</Text>
|
<Text>{evaluation.evaluation.label}</Text>
|
||||||
<Icon
|
<Icon
|
||||||
as={passed ? BsCheck : BsX}
|
as={passed ? BsCheck : BsX}
|
||||||
color={passed ? "green.500" : "red.500"}
|
color={passed ? "green.500" : "red.500"}
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
)}
|
)}
|
||||||
<HStack px={cellPadding.x} py={cellPadding.y}>
|
<HStack px={cellPadding.x} py={cellPadding.y}>
|
||||||
{data.evalResults.map((result) => {
|
{data.evalResults.map((result) => {
|
||||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
const passedFrac = result.passCount / result.totalCount;
|
||||||
return (
|
return (
|
||||||
<HStack key={result.id}>
|
<HStack key={result.id}>
|
||||||
<Text>{result.evaluation.label}</Text>
|
<Text>{result.label}</Text>
|
||||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||||
{(passedFrac * 100).toFixed(1)}%
|
{(passedFrac * 100).toFixed(1)}%
|
||||||
</Text>
|
</Text>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { reevaluateEvaluation } from "~/server/utils/evaluations";
|
import { runAllEvals } from "~/server/utils/evaluations";
|
||||||
|
|
||||||
export const evaluationsRouter = createTRPCRouter({
|
export const evaluationsRouter = createTRPCRouter({
|
||||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||||
@@ -24,7 +24,7 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input }) => {
|
||||||
const evaluation = await prisma.evaluation.create({
|
await prisma.evaluation.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
label: input.label,
|
label: input.label,
|
||||||
@@ -32,7 +32,10 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
evalType: input.evalType,
|
evalType: input.evalType,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
await reevaluateEvaluation(evaluation);
|
|
||||||
|
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
||||||
|
// to kick off a background job or something instead
|
||||||
|
await runAllEvals(input.experimentId);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
update: publicProcedure
|
update: publicProcedure
|
||||||
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
z.object({
|
z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
updates: z.object({
|
updates: z.object({
|
||||||
name: z.string().optional(),
|
label: z.string().optional(),
|
||||||
value: z.string().optional(),
|
value: z.string().optional(),
|
||||||
evalType: z.nativeEnum(EvalType).optional(),
|
evalType: z.nativeEnum(EvalType).optional(),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input }) => {
|
||||||
await prisma.evaluation.update({
|
const evaluation = await prisma.evaluation.update({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
data: {
|
data: {
|
||||||
label: input.updates.name,
|
label: input.updates.label,
|
||||||
value: input.updates.value,
|
value: input.updates.value,
|
||||||
evalType: input.updates.evalType,
|
evalType: input.updates.evalType,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
await reevaluateEvaluation(
|
|
||||||
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
|
await prisma.outputEvaluation.deleteMany({
|
||||||
);
|
where: {
|
||||||
|
evaluationId: evaluation.id,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
// Re-run all evals. Other eval results will already be cached, so this
|
||||||
|
// should only re-run the updated one.
|
||||||
|
await runAllEvals(evaluation.experimentId);
|
||||||
}),
|
}),
|
||||||
|
|
||||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||||
|
|||||||
@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const evalResults = await prisma.evaluationResult.findMany({
|
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||||
where: {
|
by: ["evaluationId"],
|
||||||
promptVariantId: input.variantId,
|
_sum: {
|
||||||
|
result: true,
|
||||||
},
|
},
|
||||||
include: { evaluation: true },
|
_count: {
|
||||||
|
id: true,
|
||||||
|
},
|
||||||
|
where: {
|
||||||
|
modelOutput: {
|
||||||
|
scenarioVariantCell: {
|
||||||
|
promptVariant: {
|
||||||
|
id: input.variantId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
testScenario: {
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const evals = await prisma.evaluation.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: variant.experimentId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const evalResults = evals.map((evalItem) => {
|
||||||
|
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
|
||||||
|
return {
|
||||||
|
id: evalItem.id,
|
||||||
|
label: evalItem.label,
|
||||||
|
passCount: evalResult?._sum?.result ?? 0,
|
||||||
|
totalCount: evalResult?._count?.id ?? 1,
|
||||||
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
const scenarioCount = await prisma.testScenario.count({
|
const scenarioCount = await prisma.testScenario.count({
|
||||||
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
promptVariantId: input.variantId,
|
promptVariantId: input.variantId,
|
||||||
testScenario: { visible: true },
|
testScenario: { visible: true },
|
||||||
modelOutput: {
|
modelOutput: {
|
||||||
isNot: null,
|
is: {},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
include: {
|
include: {
|
||||||
modelOutput: true,
|
modelOutput: {
|
||||||
|
include: {
|
||||||
|
outputEvaluation: {
|
||||||
|
include: {
|
||||||
|
evaluation: {
|
||||||
|
select: { label: true },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
|||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import { autogenerateScenarioValues } from "../autogen";
|
import { autogenerateScenarioValues } from "../autogen";
|
||||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||||
import { reevaluateAll } from "~/server/utils/evaluations";
|
import { runAllEvals } from "~/server/utils/evaluations";
|
||||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|
||||||
export const scenariosRouter = createTRPCRouter({
|
export const scenariosRouter = createTRPCRouter({
|
||||||
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Reevaluate all evaluations now that this scenario is hidden
|
// Reevaluate all evaluations now that this scenario is hidden
|
||||||
await reevaluateAll(hiddenScenario.experimentId);
|
await runAllEvals(hiddenScenario.experimentId);
|
||||||
|
|
||||||
return hiddenScenario;
|
return hiddenScenario;
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import { type JSONSerializable } from "../types";
|
|||||||
import { sleep } from "../utils/sleep";
|
import { sleep } from "../utils/sleep";
|
||||||
import { shouldStream } from "../utils/shouldStream";
|
import { shouldStream } from "../utils/shouldStream";
|
||||||
import { generateChannel } from "~/utils/generateChannel";
|
import { generateChannel } from "~/utils/generateChannel";
|
||||||
import { reevaluateVariant } from "../utils/evaluations";
|
import { runEvalsForOutput } from "../utils/evaluations";
|
||||||
import { constructPrompt } from "../utils/constructPrompt";
|
import { constructPrompt } from "../utils/constructPrompt";
|
||||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||||
import { type Prisma } from "@prisma/client";
|
import { type Prisma } from "@prisma/client";
|
||||||
@@ -148,5 +148,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await reevaluateVariant(cell.promptVariantId);
|
if (modelOutput) {
|
||||||
|
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,105 +1,154 @@
|
|||||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { evaluateOutput } from "./evaluateOutput";
|
import { runOneEval } from "./runOneEval";
|
||||||
|
import { type Scenario } from "~/components/OutputsTable/types";
|
||||||
|
|
||||||
export const reevaluateVariant = async (variantId: string) => {
|
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||||
const variant = await prisma.promptVariant.findUnique({
|
const result = runOneEval(evaluation, scenario, modelOutput);
|
||||||
where: { id: variantId },
|
return await prisma.outputEvaluation.upsert({
|
||||||
});
|
|
||||||
if (!variant) return;
|
|
||||||
|
|
||||||
const evaluations = await prisma.evaluation.findMany({
|
|
||||||
where: { experimentId: variant.experimentId },
|
|
||||||
});
|
|
||||||
|
|
||||||
const cells = await prisma.scenarioVariantCell.findMany({
|
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: variantId,
|
modelOutputId_evaluationId: {
|
||||||
retrievalStatus: "COMPLETE",
|
modelOutputId: modelOutput.id,
|
||||||
testScenario: { visible: true },
|
|
||||||
modelOutput: { isNot: null },
|
|
||||||
},
|
|
||||||
include: { testScenario: true, modelOutput: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
await Promise.all(
|
|
||||||
evaluations.map(async (evaluation) => {
|
|
||||||
const passCount = cells.filter((cell) =>
|
|
||||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
|
||||||
).length;
|
|
||||||
const failCount = cells.length - passCount;
|
|
||||||
|
|
||||||
await prisma.evaluationResult.upsert({
|
|
||||||
where: {
|
|
||||||
evaluationId_promptVariantId: {
|
|
||||||
evaluationId: evaluation.id,
|
evaluationId: evaluation.id,
|
||||||
promptVariantId: variantId,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
create: {
|
create: {
|
||||||
|
modelOutputId: modelOutput.id,
|
||||||
evaluationId: evaluation.id,
|
evaluationId: evaluation.id,
|
||||||
promptVariantId: variantId,
|
result,
|
||||||
passCount,
|
|
||||||
failCount,
|
|
||||||
},
|
},
|
||||||
update: {
|
update: {
|
||||||
passCount,
|
result,
|
||||||
failCount,
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}),
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
export const runEvalsForOutput = async (
|
||||||
const variants = await prisma.promptVariant.findMany({
|
experimentId: string,
|
||||||
where: { experimentId: evaluation.experimentId, visible: true },
|
scenario: Scenario,
|
||||||
});
|
modelOutput: ModelOutput,
|
||||||
|
) => {
|
||||||
const cells = await prisma.scenarioVariantCell.findMany({
|
|
||||||
where: {
|
|
||||||
promptVariantId: { in: variants.map((v) => v.id) },
|
|
||||||
testScenario: { visible: true },
|
|
||||||
statusCode: { notIn: [429] },
|
|
||||||
modelOutput: { isNot: null },
|
|
||||||
},
|
|
||||||
include: { testScenario: true, modelOutput: true },
|
|
||||||
});
|
|
||||||
|
|
||||||
await Promise.all(
|
|
||||||
variants.map(async (variant) => {
|
|
||||||
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
|
||||||
const passCount = variantCells.filter((cell) =>
|
|
||||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
|
||||||
).length;
|
|
||||||
const failCount = variantCells.length - passCount;
|
|
||||||
|
|
||||||
await prisma.evaluationResult.upsert({
|
|
||||||
where: {
|
|
||||||
evaluationId_promptVariantId: {
|
|
||||||
evaluationId: evaluation.id,
|
|
||||||
promptVariantId: variant.id,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
create: {
|
|
||||||
evaluationId: evaluation.id,
|
|
||||||
promptVariantId: variant.id,
|
|
||||||
passCount,
|
|
||||||
failCount,
|
|
||||||
},
|
|
||||||
update: {
|
|
||||||
passCount,
|
|
||||||
failCount,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const reevaluateAll = async (experimentId: string) => {
|
|
||||||
const evaluations = await prisma.evaluation.findMany({
|
const evaluations = await prisma.evaluation.findMany({
|
||||||
where: { experimentId },
|
where: { experimentId },
|
||||||
});
|
});
|
||||||
|
|
||||||
await Promise.all(evaluations.map(reevaluateEvaluation));
|
await Promise.all(
|
||||||
|
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||||
|
);
|
||||||
|
|
||||||
|
// const cells = await prisma.scenarioVariantCell.findMany({
|
||||||
|
// where: {
|
||||||
|
// promptVariantId: variantId,
|
||||||
|
// retrievalStatus: "COMPLETE",
|
||||||
|
// testScenario: { visible: true },
|
||||||
|
// },
|
||||||
|
// include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } },
|
||||||
|
// });
|
||||||
|
|
||||||
|
// await Promise.all(
|
||||||
|
// evaluations.map(async (evaluation) => {
|
||||||
|
// const passCount = cells.filter((cell) =>
|
||||||
|
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||||
|
// ).length;
|
||||||
|
// const failCount = cells.length - passCount;
|
||||||
|
|
||||||
|
// await prisma.evaluationResult.upsert({
|
||||||
|
// where: {
|
||||||
|
// evaluationId_promptVariantId: {
|
||||||
|
// evaluationId: evaluation.id,
|
||||||
|
// promptVariantId: variantId,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// create: {
|
||||||
|
// evaluationId: evaluation.id,
|
||||||
|
// promptVariantId: variantId,
|
||||||
|
// passCount,
|
||||||
|
// failCount,
|
||||||
|
// },
|
||||||
|
// update: {
|
||||||
|
// passCount,
|
||||||
|
// failCount,
|
||||||
|
// },
|
||||||
|
// });
|
||||||
|
// }),
|
||||||
|
// );
|
||||||
|
};
|
||||||
|
|
||||||
|
export const runAllEvals = async (experimentId: string) => {
|
||||||
|
const outputs = await prisma.modelOutput.findMany({
|
||||||
|
where: {
|
||||||
|
scenarioVariantCell: {
|
||||||
|
promptVariant: {
|
||||||
|
experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
testScenario: {
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
include: {
|
||||||
|
scenarioVariantCell: {
|
||||||
|
include: {
|
||||||
|
testScenario: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
outputEvaluation: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const evals = await prisma.evaluation.findMany({
|
||||||
|
where: { experimentId },
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
outputs.map(async (output) => {
|
||||||
|
const unrunEvals = evals.filter(
|
||||||
|
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
||||||
|
);
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
unrunEvals.map(async (evaluation) => {
|
||||||
|
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// const cells = await prisma.scenarioVariantCell.findMany({
|
||||||
|
// where: {
|
||||||
|
// promptVariantId: { in: variants.map((v) => v.id) },
|
||||||
|
// testScenario: { visible: true },
|
||||||
|
// statusCode: { notIn: [429] },
|
||||||
|
// },
|
||||||
|
// include: { testScenario: true, modelOutput: true },
|
||||||
|
// });
|
||||||
|
|
||||||
|
// await Promise.all(
|
||||||
|
// variants.map(async (variant) => {
|
||||||
|
// const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||||
|
// const passCount = variantCells.filter((cell) =>
|
||||||
|
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||||
|
// ).length;
|
||||||
|
// const failCount = variantCells.length - passCount;
|
||||||
|
|
||||||
|
// await prisma.evaluationResult.upsert({
|
||||||
|
// where: {
|
||||||
|
// evaluationId_promptVariantId: {
|
||||||
|
// evaluationId: evaluation.id,
|
||||||
|
// promptVariantId: variant.id,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// create: {
|
||||||
|
// evaluationId: evaluation.id,
|
||||||
|
// promptVariantId: variant.id,
|
||||||
|
// passCount,
|
||||||
|
// failCount,
|
||||||
|
// },
|
||||||
|
// update: {
|
||||||
|
// passCount,
|
||||||
|
// failCount,
|
||||||
|
// },
|
||||||
|
// });
|
||||||
|
// }),
|
||||||
|
// );
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -2,30 +2,31 @@ import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/cl
|
|||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
||||||
|
|
||||||
export const evaluateOutput = (
|
export const runOneEval = (
|
||||||
modelOutput: ModelOutput,
|
|
||||||
scenario: TestScenario,
|
|
||||||
evaluation: Evaluation,
|
evaluation: Evaluation,
|
||||||
): boolean => {
|
scenario: TestScenario,
|
||||||
|
modelOutput: ModelOutput,
|
||||||
|
): number => {
|
||||||
const output = modelOutput.output as unknown as ChatCompletion;
|
const output = modelOutput.output as unknown as ChatCompletion;
|
||||||
|
|
||||||
const message = output?.choices?.[0]?.message;
|
const message = output?.choices?.[0]?.message;
|
||||||
|
|
||||||
if (!message) return false;
|
if (!message) return 0;
|
||||||
|
|
||||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||||
|
|
||||||
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
|
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
|
||||||
|
|
||||||
let match;
|
let result;
|
||||||
|
|
||||||
switch (evaluation.evalType) {
|
switch (evaluation.evalType) {
|
||||||
case "CONTAINS":
|
case "CONTAINS":
|
||||||
match = stringifiedMessage.match(matchRegex) !== null;
|
result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0;
|
||||||
break;
|
break;
|
||||||
case "DOES_NOT_CONTAIN":
|
case "DOES_NOT_CONTAIN":
|
||||||
match = stringifiedMessage.match(matchRegex) === null;
|
result = stringifiedMessage.match(matchRegex) === null ? 1 : 0;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return match;
|
return result;
|
||||||
};
|
};
|
||||||
Reference in New Issue
Block a user