add evaluations

This commit is contained in:
Kyle Corbitt
2023-07-06 13:39:13 -07:00
parent 1ae5612d55
commit f728027ef6
18 changed files with 614 additions and 68 deletions

View File

@@ -0,0 +1,91 @@
import { type Evaluation } from "@prisma/client";
import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput";
/**
 * Recomputes and persists the pass/fail tally of every evaluation defined on a
 * prompt variant's experiment, using that variant's model outputs.
 *
 * Only outputs whose test scenario is visible are taken into account, and both
 * counts are derived from that same set of outputs. This keeps
 * `passCount + failCount` equal to the number of outputs evaluated, prevents
 * `failCount` from going negative when scenarios are hidden, and matches how
 * `reevaluateEvaluation` computes its counts.
 *
 * Does nothing when no variant with the given id exists.
 *
 * @param variantId - Id of the prompt variant to re-evaluate.
 * @returns A promise that settles once every evaluation result has been upserted.
 */
export const reevaluateVariant = async (variantId: string): Promise<void> => {
  const variant = await prisma.promptVariant.findUnique({
    where: { id: variantId },
  });
  if (!variant) return;

  // The two lookups are independent of each other, so issue them together.
  const [evaluations, modelOutputs] = await Promise.all([
    prisma.evaluation.findMany({
      where: { experimentId: variant.experimentId },
    }),
    prisma.modelOutput.findMany({
      // Outputs that belong to hidden scenarios must not influence the tally.
      where: { promptVariantId: variantId, testScenario: { visible: true } },
      include: { testScenario: true },
    }),
  ]);

  await Promise.all(
    evaluations.map(async (evaluation) => {
      const passCount = modelOutputs.filter((output) =>
        evaluateOutput(output, output.testScenario, evaluation)
      ).length;
      // Every evaluated output that did not pass is a failure.
      const failCount = modelOutputs.length - passCount;

      await prisma.evaluationResult.upsert({
        where: {
          evaluationId_promptVariantId: {
            evaluationId: evaluation.id,
            promptVariantId: variantId,
          },
        },
        create: {
          evaluationId: evaluation.id,
          promptVariantId: variantId,
          passCount,
          failCount,
        },
        update: {
          passCount,
          failCount,
        },
      });
    })
  );
};
/**
 * Recalculates and stores the result of a single evaluation for each visible
 * prompt variant in the evaluation's experiment.
 *
 * For every such variant, the model outputs attached to visible test scenarios
 * are run through `evaluateOutput`; the number of passing and non-passing
 * outputs is then upserted into the evaluation result row identified by the
 * (evaluation, variant) pair.
 *
 * @param evaluation - The evaluation whose results should be refreshed.
 */
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
  const visibleVariants = await prisma.promptVariant.findMany({
    where: { experimentId: evaluation.experimentId, visible: true },
  });

  const variantIds = visibleVariants.map((v) => v.id);
  const candidateOutputs = await prisma.modelOutput.findMany({
    where: { promptVariantId: { in: variantIds }, testScenario: { visible: true } },
    include: { testScenario: true },
  });

  const refreshVariant = async (variant: (typeof visibleVariants)[number]) => {
    // Walk the outputs once, counting those owned by this variant and how
    // many of them satisfy the evaluation.
    let total = 0;
    let passCount = 0;
    for (const output of candidateOutputs) {
      if (output.promptVariantId !== variant.id) continue;
      total += 1;
      if (evaluateOutput(output, output.testScenario, evaluation)) {
        passCount += 1;
      }
    }
    const failCount = total - passCount;

    const key = { evaluationId: evaluation.id, promptVariantId: variant.id };
    const counts = { passCount, failCount };

    await prisma.evaluationResult.upsert({
      where: { evaluationId_promptVariantId: key },
      create: { ...key, ...counts },
      update: counts,
    });
  };

  await Promise.all(visibleVariants.map((variant) => refreshVariant(variant)));
};