Ensure evals run properly (#96)

* Run evals against llama output

* Continue polling in OutputCell until evals complete

* Remove unnecessary check
This commit is contained in:
arcticfly
2023-07-25 20:01:58 -07:00
committed by GitHub
parent 98b231c8bd
commit d4fb8b689a
5 changed files with 66 additions and 41 deletions

View File

@@ -63,6 +63,7 @@ export default function OutputCell({
const awaitingOutput = const awaitingOutput =
!cell || !cell ||
!cell.evalsComplete ||
cell.retrievalStatus === "PENDING" || cell.retrievalStatus === "PENDING" ||
cell.retrievalStatus === "IN_PROGRESS" || cell.retrievalStatus === "IN_PROGRESS" ||
hardRefetching; hardRefetching;

View File

@@ -19,30 +19,45 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}); });
await requireCanViewExperiment(experimentId, ctx); await requireCanViewExperiment(experimentId, ctx);
return await prisma.scenarioVariantCell.findUnique({ const [cell, numTotalEvals] = await prisma.$transaction([
where: { prisma.scenarioVariantCell.findUnique({
promptVariantId_testScenarioId: { where: {
promptVariantId: input.variantId, promptVariantId_testScenarioId: {
testScenarioId: input.scenarioId, promptVariantId: input.variantId,
}, testScenarioId: input.scenarioId,
},
include: {
modelResponses: {
where: {
outdated: false,
}, },
include: { },
outputEvaluations: { include: {
include: { modelResponses: {
evaluation: { where: {
select: { label: true }, outdated: false,
},
include: {
outputEvaluations: {
include: {
evaluation: {
select: { label: true },
},
}, },
}, },
}, },
}, },
}, },
}, }),
}); prisma.evaluation.count({
where: { experimentId },
}),
]);
if (!cell) return null;
const lastResponse = cell.modelResponses?.[cell.modelResponses?.length - 1];
const evalsComplete = lastResponse?.outputEvaluations?.length === numTotalEvals;
return {
...cell,
evalsComplete,
};
}), }),
forceRefetch: protectedProcedure forceRefetch: protectedProcedure
.input( .input(

View File

@@ -99,7 +99,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
const inputHash = hashPrompt(prompt); const inputHash = hashPrompt(prompt);
for (let i = 0; true; i++) { for (let i = 0; true; i++) {
const modelResponse = await prisma.modelResponse.create({ let modelResponse = await prisma.modelResponse.create({
data: { data: {
inputHash, inputHash,
scenarioVariantCellId: cellId, scenarioVariantCellId: cellId,
@@ -108,7 +108,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}); });
const response = await provider.getCompletion(prompt.modelInput, onStream); const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") { if (response.type === "success") {
await prisma.modelResponse.update({ modelResponse = await prisma.modelResponse.update({
where: { id: modelResponse.id }, where: { id: modelResponse.id },
data: { data: {
output: response.value as Prisma.InputJsonObject, output: response.value as Prisma.InputJsonObject,
@@ -127,7 +127,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}, },
}); });
await runEvalsForOutput(variant.experimentId, scenario, modelResponse); await runEvalsForOutput(variant.experimentId, scenario, modelResponse, prompt.modelProvider);
break; break;
} else { } else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES; const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;

View File

@@ -2,13 +2,15 @@ import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { runOneEval } from "./runOneEval"; import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types"; import { type Scenario } from "~/components/OutputsTable/types";
import { type SupportedProvider } from "~/modelProviders/types";
const saveResult = async ( const runAndSaveEval = async (
evaluation: Evaluation, evaluation: Evaluation,
scenario: Scenario, scenario: Scenario,
modelResponse: ModelResponse, modelResponse: ModelResponse,
provider: SupportedProvider,
) => { ) => {
const result = await runOneEval(evaluation, scenario, modelResponse); const result = await runOneEval(evaluation, scenario, modelResponse, provider);
return await prisma.outputEvaluation.upsert({ return await prisma.outputEvaluation.upsert({
where: { where: {
modelResponseId_evaluationId: { modelResponseId_evaluationId: {
@@ -31,13 +33,16 @@ export const runEvalsForOutput = async (
experimentId: string, experimentId: string,
scenario: Scenario, scenario: Scenario,
modelResponse: ModelResponse, modelResponse: ModelResponse,
provider: SupportedProvider,
) => { ) => {
const evaluations = await prisma.evaluation.findMany({ const evaluations = await prisma.evaluation.findMany({
where: { experimentId }, where: { experimentId },
}); });
await Promise.all( await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)), evaluations.map(
async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
),
); );
}; };
@@ -62,6 +67,7 @@ export const runAllEvals = async (experimentId: string) => {
scenarioVariantCell: { scenarioVariantCell: {
include: { include: {
testScenario: true, testScenario: true,
promptVariant: true,
}, },
}, },
outputEvaluations: true, outputEvaluations: true,
@@ -73,13 +79,18 @@ export const runAllEvals = async (experimentId: string) => {
await Promise.all( await Promise.all(
outputs.map(async (output) => { outputs.map(async (output) => {
const unrunEvals = evals.filter( const evalsToBeRun = evals.filter(
(evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id), (evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
); );
await Promise.all( await Promise.all(
unrunEvals.map(async (evaluation) => { evalsToBeRun.map(async (evaluation) => {
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output); await runAndSaveEval(
evaluation,
output.scenarioVariantCell.testScenario,
output,
output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
);
}), }),
); );
}), }),

View File

@@ -1,13 +1,14 @@
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client"; import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate"; import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
import { openai } from "./openai"; import { openai } from "./openai";
import dedent from "dedent"; import dedent from "dedent";
import modelProviders from "~/modelProviders/modelProviders";
import { type SupportedProvider } from "~/modelProviders/types";
export const runGpt4Eval = async ( export const runGpt4Eval = async (
evaluation: Evaluation, evaluation: Evaluation,
scenario: TestScenario, scenario: TestScenario,
message: ChatCompletion.Choice.Message, stringifiedOutput: string,
): Promise<{ result: number; details: string }> => { ): Promise<{ result: number; details: string }> => {
const output = await openai.chat.completions.create({ const output = await openai.chat.completions.create({
model: "gpt-4-0613", model: "gpt-4-0613",
@@ -26,11 +27,7 @@ export const runGpt4Eval = async (
}, },
{ {
role: "user", role: "user",
content: `The full output of the simpler message:\n---\n${JSON.stringify( content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
message.content ?? message.function_call,
null,
2,
)}`,
}, },
], ],
function_call: { function_call: {
@@ -71,14 +68,15 @@ export const runOneEval = async (
evaluation: Evaluation, evaluation: Evaluation,
scenario: TestScenario, scenario: TestScenario,
modelResponse: ModelResponse, modelResponse: ModelResponse,
provider: SupportedProvider,
): Promise<{ result: number; details?: string }> => { ): Promise<{ result: number; details?: string }> => {
const output = modelResponse.output as unknown as ChatCompletion; const modelProvider = modelProviders[provider];
const message = modelProvider.normalizeOutput(modelResponse.output);
const message = output?.choices?.[0]?.message;
if (!message) return { result: 0 }; if (!message) return { result: 0 };
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call); const stringifiedOutput =
message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;
const matchRegex = escapeRegExp( const matchRegex = escapeRegExp(
fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap), fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
@@ -86,10 +84,10 @@ export const runOneEval = async (
switch (evaluation.evalType) { switch (evaluation.evalType) {
case "CONTAINS": case "CONTAINS":
return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 }; return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
case "DOES_NOT_CONTAIN": case "DOES_NOT_CONTAIN":
return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 }; return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
case "GPT4_EVAL": case "GPT4_EVAL":
return await runGpt4Eval(evaluation, scenario, message); return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
} }
}; };