Store multiple ModelResponses (#95)
* Store multiple ModelResponses * Fix prettier * Add CellContent container
This commit is contained in:
@@ -1,19 +1,23 @@
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { type ModelResponse, type Evaluation, Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
|
||||
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelOutput);
|
||||
const saveResult = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: Scenario,
|
||||
modelResponse: ModelResponse,
|
||||
) => {
|
||||
const result = await runOneEval(evaluation, scenario, modelResponse);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
modelOutputId_evaluationId: {
|
||||
modelOutputId: modelOutput.id,
|
||||
modelResponseId_evaluationId: {
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelOutputId: modelOutput.id,
|
||||
modelResponseId: modelResponse.id,
|
||||
evaluationId: evaluation.id,
|
||||
...result,
|
||||
},
|
||||
@@ -26,20 +30,24 @@ const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutpu
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelOutput: ModelOutput,
|
||||
modelResponse: ModelResponse,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelResponse)),
|
||||
);
|
||||
};
|
||||
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelOutput.findMany({
|
||||
const outputs = await prisma.modelResponse.findMany({
|
||||
where: {
|
||||
outdated: false,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import parseConstructFn from "./parseConstructFn";
|
||||
import { type JsonObject } from "type-fest";
|
||||
@@ -35,7 +35,7 @@ export const generateNewCell = async (
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -51,8 +51,6 @@ export const generateNewCell = async (
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
statusCode: 400,
|
||||
errorMessage: parsedConstructFn.error,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
@@ -69,36 +67,55 @@ export const generateNewCell = async (
|
||||
retrievalStatus: "PENDING",
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelResponses: true,
|
||||
},
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: { inputHash },
|
||||
const matchingModelResponse = await prisma.modelResponse.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
output: {
|
||||
not: Prisma.AnyNull,
|
||||
},
|
||||
},
|
||||
orderBy: {
|
||||
receivedAt: "desc",
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: true,
|
||||
},
|
||||
take: 1,
|
||||
});
|
||||
|
||||
if (matchingModelOutput) {
|
||||
const newModelOutput = await prisma.modelOutput.create({
|
||||
if (matchingModelResponse) {
|
||||
const newModelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
...omit(matchingModelOutput, ["id"]),
|
||||
...omit(matchingModelResponse, ["id", "scenarioVariantCell"]),
|
||||
scenarioVariantCellId: cell.id,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
output: matchingModelResponse.output as Prisma.InputJsonValue,
|
||||
},
|
||||
});
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cell.id },
|
||||
data: { retrievalStatus: "COMPLETE" },
|
||||
data: {
|
||||
retrievalStatus: "COMPLETE",
|
||||
jobStartedAt: matchingModelResponse.scenarioVariantCell.jobStartedAt,
|
||||
jobQueuedAt: matchingModelResponse.scenarioVariantCell.jobQueuedAt,
|
||||
},
|
||||
});
|
||||
|
||||
// Copy over all eval results as well
|
||||
await Promise.all(
|
||||
(
|
||||
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
|
||||
await prisma.outputEvaluation.findMany({
|
||||
where: { modelResponseId: matchingModelResponse.id },
|
||||
})
|
||||
).map(async (evaluation) => {
|
||||
await prisma.outputEvaluation.create({
|
||||
data: {
|
||||
...omit(evaluation, ["id"]),
|
||||
modelOutputId: newModelOutput.id,
|
||||
modelResponseId: newModelResponse.id,
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||
import { type Evaluation, type ModelResponse, type TestScenario } from "@prisma/client";
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate, escapeRegExp, escapeQuotes } from "./fillTemplate";
|
||||
import { openai } from "./openai";
|
||||
@@ -70,9 +70,9 @@ export const runGpt4Eval = async (
|
||||
export const runOneEval = async (
|
||||
evaluation: Evaluation,
|
||||
scenario: TestScenario,
|
||||
modelOutput: ModelOutput,
|
||||
modelResponse: ModelResponse,
|
||||
): Promise<{ result: number; details?: string }> => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
const output = modelResponse.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user