Ensure evals run properly (#96)

* Run evals against llama output

* Continue polling in OutputCell until evals complete (see the polling sketch after this list)

* Remove unnecessary check
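
The OutputCell change referenced in the second bullet is in one of the five changed files but is not shown in the hunks below. As a rough illustration of the intended behavior — keep polling until the evals themselves finish, not just until output arrives — here is a minimal sketch; `fetchEvaluations`, `EvalStatus`, and the `complete` flag are all assumed names, not this repo's actual API:

```ts
// Illustrative sketch only: poll until every evaluation for a cell reports complete.
// `fetchEvaluations` and the `complete` field are assumed names, not the actual API.
type EvalStatus = { id: string; complete: boolean };

declare function fetchEvaluations(cellId: string): Promise<EvalStatus[]>;

async function pollUntilEvalsComplete(
  cellId: string,
  intervalMs = 2000,
): Promise<EvalStatus[]> {
  for (;;) {
    const evals = await fetchEvaluations(cellId);
    // The fix described above: don't stop once output arrives; keep going
    // until the evals themselves are done.
    if (evals.length > 0 && evals.every((e) => e.complete)) return evals;
    await new Promise((resolve) => setTimeout(resolve, intervalMs));
  }
}
```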
Author: arcticfly (committed by GitHub)
Date: 2023-07-25 20:01:58 -07:00
Parent: 98b231c8bd
Commit: d4fb8b689a
5 changed files with 66 additions and 41 deletions

@@ -2,13 +2,15 @@ import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
 import { prisma } from "../db";
 import { runOneEval } from "./runOneEval";
 import { type Scenario } from "~/components/OutputsTable/types";
+import { type SupportedProvider } from "~/modelProviders/types";
 
-const saveResult = async (
+const runAndSaveEval = async (
   evaluation: Evaluation,
   scenario: Scenario,
   modelResponse: ModelResponse,
+  provider: SupportedProvider,
 ) => {
-  const result = await runOneEval(evaluation, scenario, modelResponse);
+  const result = await runOneEval(evaluation, scenario, modelResponse, provider);
   return await prisma.outputEvaluation.upsert({
     where: {
       modelResponseId_evaluationId: {
@@ -31,13 +33,16 @@ export const runEvalsForOutput = async (
   experimentId: string,
   scenario: Scenario,
   modelResponse: ModelResponse,
+  provider: SupportedProvider,
 ) => {
   const evaluations = await prisma.evaluation.findMany({
     where: { experimentId },
   });
 
   await Promise.all(
-    evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
+    evaluations.map(
+      async (evaluation) => await runAndSaveEval(evaluation, scenario, modelResponse, provider),
+    ),
   );
 };
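
With `provider` threaded through, callers must now pass it alongside the scenario and response. A hedged sketch of a call site under assumed surrounding types (nothing here beyond `runEvalsForOutput` and `SupportedProvider` comes from the commit):

```ts
// Illustrative call site; `cell`, `onResponseSaved`, and the module path are assumed.
import { type ModelResponse } from "@prisma/client";
import { type Scenario } from "~/components/OutputsTable/types";
import { type SupportedProvider } from "~/modelProviders/types";
import { runEvalsForOutput } from "./runEvalsForOutput"; // hypothetical path

async function onResponseSaved(
  cell: {
    experimentId: string;
    testScenario: Scenario;
    promptVariant: { modelProvider: string };
  },
  modelResponse: ModelResponse,
): Promise<void> {
  // The provider travels with the cell's prompt variant, so evals normalize
  // llama output the same way the table does. modelProvider is persisted as
  // a plain string, hence the cast at this boundary.
  await runEvalsForOutput(
    cell.experimentId,
    cell.testScenario,
    modelResponse,
    cell.promptVariant.modelProvider as SupportedProvider,
  );
}
```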
@@ -62,6 +67,7 @@ export const runAllEvals = async (experimentId: string) => {
       scenarioVariantCell: {
         include: {
           testScenario: true,
+          promptVariant: true,
         },
       },
       outputEvaluations: true,
@@ -73,13 +79,18 @@ export const runAllEvals = async (experimentId: string) => {
   await Promise.all(
     outputs.map(async (output) => {
-      const unrunEvals = evals.filter(
+      const evalsToBeRun = evals.filter(
        (evaluation) => !output.outputEvaluations.find((e) => e.evaluationId === evaluation.id),
       );
       await Promise.all(
-        unrunEvals.map(async (evaluation) => {
-          await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
+        evalsToBeRun.map(async (evaluation) => {
+          await runAndSaveEval(
+            evaluation,
+            output.scenarioVariantCell.testScenario,
+            output,
+            output.scenarioVariantCell.promptVariant.modelProvider as SupportedProvider,
+          );
         }),
       );
     }),
   );
 };
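
The `as SupportedProvider` cast above trusts whatever string Prisma stored on `promptVariant.modelProvider`. If a stricter boundary were wanted, the keys of the `modelProviders` registry could back a runtime guard; a small sketch of that idea (not part of this commit):

```ts
import modelProviders from "~/modelProviders/modelProviders";
import { type SupportedProvider } from "~/modelProviders/types";

// Narrow a stored string to SupportedProvider by checking it against the
// actual provider registry, instead of casting unconditionally.
const isSupportedProvider = (value: string): value is SupportedProvider =>
  value in modelProviders;
```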

runOneEval.ts (the module imported above as "./runOneEval"):

@@ -1,13 +1,14 @@
 import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
-import { type ChatCompletion } from "openai/resources/chat";
 import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
 import { openai } from "./openai";
 import dedent from "dedent";
+import modelProviders from "~/modelProviders/modelProviders";
+import { type SupportedProvider } from "~/modelProviders/types";
 
 export const runGpt4Eval = async (
   evaluation: Evaluation,
   scenario: TestScenario,
-  message: ChatCompletion.Choice.Message,
+  stringifiedOutput: string,
 ): Promise<{ result: number; details: string }> => {
   const output = await openai.chat.completions.create({
     model: "gpt-4-0613",
@@ -26,11 +27,7 @@ export const runGpt4Eval = async (
       },
       {
         role: "user",
-        content: `The full output of the simpler message:\n---\n${JSON.stringify(
-          message.content ?? message.function_call,
-          null,
-          2,
-        )}`,
+        content: `The full output of the simpler message:\n---\n${stringifiedOutput}`,
       },
     ],
     function_call: {
@@ -71,14 +68,15 @@ export const runOneEval = async (
   evaluation: Evaluation,
   scenario: TestScenario,
   modelResponse: ModelResponse,
+  provider: SupportedProvider,
 ): Promise<{ result: number; details?: string }> => {
-  const output = modelResponse.output as unknown as ChatCompletion;
-  const message = output?.choices?.[0]?.message;
+  const modelProvider = modelProviders[provider];
+  const message = modelProvider.normalizeOutput(modelResponse.output);
   if (!message) return { result: 0 };
-  const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
+  const stringifiedOutput =
+    message.type === "json" ? JSON.stringify(message.value, null, 2) : message.value;
 
   const matchRegex = escapeRegExp(
     fillTemplate(escapeQuotes(evaluation.value), scenario.variableValues as VariableMap),
@@ -86,10 +84,10 @@
   switch (evaluation.evalType) {
     case "CONTAINS":
-      return { result: stringifiedMessage.match(matchRegex) !== null ? 1 : 0 };
+      return { result: stringifiedOutput.match(matchRegex) !== null ? 1 : 0 };
     case "DOES_NOT_CONTAIN":
-      return { result: stringifiedMessage.match(matchRegex) === null ? 1 : 0 };
+      return { result: stringifiedOutput.match(matchRegex) === null ? 1 : 0 };
     case "GPT4_EVAL":
-      return await runGpt4Eval(evaluation, scenario, message);
+      return await runGpt4Eval(evaluation, scenario, stringifiedOutput);
   }
 };
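
The return shape of `normalizeOutput` is not part of this diff, but the call sites above (a nullable result, a `type` discriminator, and a `value` that is JSON-stringified when `type === "json"`) imply something like the following. This is an inference from usage, not the repo's actual declaration:

```ts
// Inferred from usage in runOneEval; illustrative, not the actual declaration.
type NormalizedOutput =
  | { type: "text"; value: string }
  | { type: "json"; value: unknown };

interface ModelProvider {
  // Returns null/undefined when the raw response has no usable message,
  // which runOneEval maps to a failing result of 0.
  normalizeOutput(output: unknown): NormalizedOutput | null | undefined;
}
```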