Ensure evals run properly (#96)
* Run evals against llama output
* Continue polling in OutputCell until evals complete
* Remove unnecessary check
This commit is contained in:
@@ -63,6 +63,7 @@ export default function OutputCell({
|
||||
|
||||
const awaitingOutput =
|
||||
!cell ||
|
||||
!cell.evalsComplete ||
|
||||
cell.retrievalStatus === "PENDING" ||
|
||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||
hardRefetching;
|
||||
|
||||
@@ -19,30 +19,45 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
});
|
||||
await requireCanViewExperiment(experimentId, ctx);
|
||||
|
||||
return await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
const [cell, numTotalEvals] = await prisma.$transaction([
|
||||
prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: input.variantId,
|
||||
testScenarioId: input.scenarioId,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
include: {
|
||||
modelResponses: {
|
||||
where: {
|
||||
outdated: false,
|
||||
},
|
||||
include: {
|
||||
outputEvaluations: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
prisma.evaluation.count({
|
||||
where: { experimentId },
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!cell) return null;
|
||||
|
||||
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
|
||||
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
|
||||
|
||||
return {
|
||||
...cell,
|
||||
evalsComplete,
|
||||
};
|
||||
}),
|
||||
forceRefetch: protectedProcedure
|
||||
.input(
|
||||
|
||||
@@ -99,7 +99,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
const modelResponse = await prisma.modelResponse.create({
|
||||
let modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
@@ -108,7 +108,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
await prisma.modelResponse.update({
|
||||
modelResponse = await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
@@ -127,7 +127,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||
break;
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||
|
||||
@@ -2,13 +2,15 @@ import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
const saveResult = async (
|
||||
const runAndSaveEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelResponse);
|
||||
const result = await runOneEval(evaluation, scenario, modelResponse, provider);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
modelResponseId_evaluationId: {
|
||||
@@ -31,13 +33,16 @@ export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
|
||||
evaluations.map(
|
||||
async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
@@ -62,6 +67,7 @@ export const runAllEvals = async (experimentId: string) => {
|
||||
scenarioVariantCell: {
|
||||
include: {
|
||||
testScenario: true,
|
||||
promptVariant: true,
|
||||
},
|
||||
},
|
||||
outputEvaluations: true,
|
||||
@@ -73,13 +79,18 @@ export const runAllEvals = async (experimentId: string) => {
|
||||
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const unrunEvals = evals.filter(
|
||||
const evalsToBeRun = evals.filter(
|
||||
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
unrunEvals.map(async (evaluation) => {
|
||||
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
||||
evalsToBeRun.map(async (evaluation) => {
|
||||
await runAndSaveEval(
|
||||
evaluation,
|
||||
output.scenarioVariantCell.testScenario,
|
||||
output,
|
||||
output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
|
||||
);
|
||||
}),
|
||||
);
|
||||
}),
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
import dedent from "dedent";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { type SupportedProvider } from "~/modelProviders/types";
|
||||
|
||||
export const runGpt4Eval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
message: ChatCompletion.Choice.Message,
|
||||
stringifiedOutput: string,
|
||||
): Promise<{ result: number; details: string }> => {
|
||||
const output = await openai.chat.completions.create({
|
||||
model: "gpt-4-0613",
|
||||
@@ -26,11 +27,7 @@ export const runGpt4Eval = async (
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: `The full output of the simpler message:\n---\n${JSON.stringify(
|
||||
message.content ?? message.function_call,
|
||||
null,
|
||||
2,
|
||||
)}`,
|
||||
content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
|
||||
},
|
||||
],
|
||||
function_call: {
|
||||
@@ -71,14 +68,15 @@ export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelResponse: ModelResponse,
|
||||
provider: SupportedProvider,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const output = modelResponse.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
const modelProvider = modelProviders[provider];
|
||||
const message = modelProvider.normalizeOutput(modelResponse.output);
|
||||
|
||||
if (!message) return { result: 0 };
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
const stringifiedOutput =
|
||||
message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;
|
||||
|
||||
const matchRegex = escapeRegExp(
|
||||
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
||||
@@ -86,10 +84,10 @@ export const runOneEval = async (
|
||||
|
||||
switch (evaluation.evalType) {
|
||||
case "CONTAINS":
|
||||
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
|
||||
return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
|
||||
case "DOES_NOT_CONTAIN":
|
||||
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
|
||||
return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
|
||||
case "GPT4_EVAL":
|
||||
return await runGpt4Eval(evaluation, scenario, message);
|
||||
return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user