Ensure evals run properly (#96)
* Run evals against llama output * Continue polling in OutputCell until evals complete * Remove unnecessary check
This commit is contained in:
@@ -63,6 +63,7 @@ export default function OutputCell({
|
|||||||
|
|
||||||
const awaitingOutput =
|
const awaitingOutput =
|
||||||
!cell ||
|
!cell ||
|
||||||
|
!cell.evalsComplete ||
|
||||||
cell.retrievalStatus === "PENDING" ||
|
cell.retrievalStatus === "PENDING" ||
|
||||||
cell.retrievalStatus === "IN_PROGRESS" ||
|
cell.retrievalStatus === "IN_PROGRESS" ||
|
||||||
hardRefetching;
|
hardRefetching;
|
||||||
|
|||||||
@@ -19,30 +19,45 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
await requireCanViewExperiment(experimentId, ctx);
|
await requireCanViewExperiment(experimentId, ctx);
|
||||||
|
|
||||||
return await prisma.scenarioVariantCell.findUnique({
|
const [cell, numTotalEvals] = await prisma.$transaction([
|
||||||
where: {
|
prisma.scenarioVariantCell.findUnique({
|
||||||
promptVariantId_testScenarioId: {
|
where: {
|
||||||
promptVariantId: input.variantId,
|
promptVariantId_testScenarioId: {
|
||||||
testScenarioId: input.scenarioId,
|
promptVariantId: input.variantId,
|
||||||
},
|
testScenarioId: input.scenarioId,
|
||||||
},
|
|
||||||
include: {
|
|
||||||
modelResponses: {
|
|
||||||
where: {
|
|
||||||
outdated: false,
|
|
||||||
},
|
},
|
||||||
include: {
|
},
|
||||||
outputEvaluations: {
|
include: {
|
||||||
include: {
|
modelResponses: {
|
||||||
evaluation: {
|
where: {
|
||||||
select: { label: true },
|
outdated: false,
|
||||||
|
},
|
||||||
|
include: {
|
||||||
|
outputEvaluations: {
|
||||||
|
include: {
|
||||||
|
evaluation: {
|
||||||
|
select: { label: true },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
});
|
prisma.evaluation.count({
|
||||||
|
where: { experimentId },
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
|
||||||
|
if (!cell) return null;
|
||||||
|
|
||||||
|
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
|
||||||
|
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
|
||||||
|
|
||||||
|
return {
|
||||||
|
...cell,
|
||||||
|
evalsComplete,
|
||||||
|
};
|
||||||
}),
|
}),
|
||||||
forceRefetch: protectedProcedure
|
forceRefetch: protectedProcedure
|
||||||
.input(
|
.input(
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
|||||||
const inputHash = hashPrompt(prompt);
|
const inputHash = hashPrompt(prompt);
|
||||||
|
|
||||||
for (let i = 0; true; i++) {
|
for (let i = 0; true; i++) {
|
||||||
const modelResponse = await prisma.modelResponse.create({
|
let modelResponse = await prisma.modelResponse.create({
|
||||||
data: {
|
data: {
|
||||||
inputHash,
|
inputHash,
|
||||||
scenarioVariantCellId: cellId,
|
scenarioVariantCellId: cellId,
|
||||||
@@ -108,7 +108,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
|||||||
});
|
});
|
||||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||||
if (response.type === "success") {
|
if (response.type === "success") {
|
||||||
await prisma.modelResponse.update({
|
modelResponse = await prisma.modelResponse.update({
|
||||||
where: { id: modelResponse.id },
|
where: { id: modelResponse.id },
|
||||||
data: {
|
data: {
|
||||||
output: response.value as Prisma.InputJsonObject,
|
output: response.value as Prisma.InputJsonObject,
|
||||||
@@ -127,7 +127,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
|
await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||||
|
|||||||
@@ -2,13 +2,15 @@ import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
|
|||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { runOneEval } from "./runOneEval";
|
import { runOneEval } from "./runOneEval";
|
||||||
import { type Scenario } from "~/components/OutputsTable/types";
|
import { type Scenario } from "~/components/OutputsTable/types";
|
||||||
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
const saveResult = async (
|
const runAndSaveEval = async (
|
||||||
evaluation: Evaluation,
|
evaluation: Evaluation,
|
||||||
scenario: Scenario,
|
scenario: Scenario,
|
||||||
modelResponse: ModelResponse,
|
modelResponse: ModelResponse,
|
||||||
|
provider: SupportedProvider,
|
||||||
) => {
|
) => {
|
||||||
const result = await runOneEval(evaluation, scenario, modelResponse);
|
const result = await runOneEval(evaluation, scenario, modelResponse, provider);
|
||||||
return await prisma.outputEvaluation.upsert({
|
return await prisma.outputEvaluation.upsert({
|
||||||
where: {
|
where: {
|
||||||
modelResponseId_evaluationId: {
|
modelResponseId_evaluationId: {
|
||||||
@@ -31,13 +33,16 @@ export const runEvalsForOutput = async (
|
|||||||
experimentId: string,
|
experimentId: string,
|
||||||
scenario: Scenario,
|
scenario: Scenario,
|
||||||
modelResponse: ModelResponse,
|
modelResponse: ModelResponse,
|
||||||
|
provider: SupportedProvider,
|
||||||
) => {
|
) => {
|
||||||
const evaluations = await prisma.evaluation.findMany({
|
const evaluations = await prisma.evaluation.findMany({
|
||||||
where: { experimentId },
|
where: { experimentId },
|
||||||
});
|
});
|
||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
|
evaluations.map(
|
||||||
|
async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -62,6 +67,7 @@ export const runAllEvals = async (experimentId: string) => {
|
|||||||
scenarioVariantCell: {
|
scenarioVariantCell: {
|
||||||
include: {
|
include: {
|
||||||
testScenario: true,
|
testScenario: true,
|
||||||
|
promptVariant: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
outputEvaluations: true,
|
outputEvaluations: true,
|
||||||
@@ -73,13 +79,18 @@ export const runAllEvals = async (experimentId: string) => {
|
|||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
outputs.map(async (output) => {
|
outputs.map(async (output) => {
|
||||||
const unrunEvals = evals.filter(
|
const evalsToBeRun = evals.filter(
|
||||||
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
|
||||||
);
|
);
|
||||||
|
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
unrunEvals.map(async (evaluation) => {
|
evalsToBeRun.map(async (evaluation) => {
|
||||||
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
await runAndSaveEval(
|
||||||
|
evaluation,
|
||||||
|
output.scenarioVariantCell.testScenario,
|
||||||
|
output,
|
||||||
|
output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
|
||||||
|
);
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
|
||||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||||
import { openai } from "./openai";
|
import { openai } from "./openai";
|
||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
|
import modelProviders from "~/modelProviders/modelProviders";
|
||||||
|
import { type SupportedProvider } from "~/modelProviders/types";
|
||||||
|
|
||||||
export const runGpt4Eval = async (
|
export const runGpt4Eval = async (
|
||||||
evaluation: Evaluation,
|
evaluation: Evaluation,
|
||||||
scenario: TestScenario,
|
scenario: TestScenario,
|
||||||
message: ChatCompletion.Choice.Message,
|
stringifiedOutput: string,
|
||||||
): Promise<{ result: number; details: string }> => {
|
): Promise<{ result: number; details: string }> => {
|
||||||
const output = await openai.chat.completions.create({
|
const output = await openai.chat.completions.create({
|
||||||
model: "gpt-4-0613",
|
model: "gpt-4-0613",
|
||||||
@@ -26,11 +27,7 @@ export const runGpt4Eval = async (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content: `The full output of the simpler message:\n---\n${JSON.stringify(
|
content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
|
||||||
message.content ?? message.function_call,
|
|
||||||
null,
|
|
||||||
2,
|
|
||||||
)}`,
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
function_call: {
|
function_call: {
|
||||||
@@ -71,14 +68,15 @@ export const runOneEval = async (
|
|||||||
evaluation: Evaluation,
|
evaluation: Evaluation,
|
||||||
scenario: TestScenario,
|
scenario: TestScenario,
|
||||||
modelResponse: ModelResponse,
|
modelResponse: ModelResponse,
|
||||||
|
provider: SupportedProvider,
|
||||||
): Promise<{ result: number; details?: string }> => {
|
): Promise<{ result: number; details?: string }> => {
|
||||||
const output = modelResponse.output as unknown as ChatCompletion;
|
const modelProvider = modelProviders[provider];
|
||||||
|
const message = modelProvider.normalizeOutput(modelResponse.output);
|
||||||
const message = output?.choices?.[0]?.message;
|
|
||||||
|
|
||||||
if (!message) return { result: 0 };
|
if (!message) return { result: 0 };
|
||||||
|
|
||||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
const stringifiedOutput =
|
||||||
|
message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;
|
||||||
|
|
||||||
const matchRegex = escapeRegExp(
|
const matchRegex = escapeRegExp(
|
||||||
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
|
||||||
@@ -86,10 +84,10 @@ export const runOneEval = async (
|
|||||||
|
|
||||||
switch (evaluation.evalType) {
|
switch (evaluation.evalType) {
|
||||||
case "CONTAINS":
|
case "CONTAINS":
|
||||||
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
|
return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
|
||||||
case "DOES_NOT_CONTAIN":
|
case "DOES_NOT_CONTAIN":
|
||||||
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
|
return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
|
||||||
case "GPT4_EVAL":
|
case "GPT4_EVAL":
|
||||||
return await runGpt4Eval(evaluation, scenario, message);
|
return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user