Copy over evals when new cell created

Fixes a bug where new cells generated as clones of existing cells didn't get the eval results cloned as well.
This commit is contained in:
Kyle Corbitt
2023-07-21 18:38:34 -07:00
parent 46036a44d2
commit 52d1d5c7ee
2 changed files with 24 additions and 18 deletions

View File

@@ -4,8 +4,9 @@ import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
import parseConstructFn from "./parseConstructFn";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
import { omit } from "lodash-es";
export const generateNewCell = async (variantId: string, scenarioId: string) => {
export const generateNewCell = async (variantId: string, scenarioId: string): Promise<void> => {
const variant = await prisma.promptVariant.findUnique({
where: {
id: variantId,
@@ -18,7 +19,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (!variant || !scenario) return null;
if (!variant || !scenario) return;
let cell = await prisma.scenarioVariantCell.findUnique({
where: {
@@ -32,7 +33,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
},
});
if (cell) return cell;
if (cell) return;
const parsedConstructFn = await parseConstructFn(
variant.constructFn,
@@ -40,7 +41,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
);
if ("error" in parsedConstructFn) {
return await prisma.scenarioVariantCell.create({
await prisma.scenarioVariantCell.create({
data: {
promptVariantId: variantId,
testScenarioId: scenarioId,
@@ -49,6 +50,7 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
retrievalStatus: "ERROR",
},
});
return;
}
const inputHash = hashPrompt(parsedConstructFn);
@@ -69,29 +71,33 @@ export const generateNewCell = async (variantId: string, scenarioId: string) =>
where: { inputHash },
});
let newModelOutput;
if (matchingModelOutput) {
newModelOutput = await prisma.modelOutput.create({
const newModelOutput = await prisma.modelOutput.create({
data: {
...omit(matchingModelOutput, ["id"]),
scenarioVariantCellId: cell.id,
inputHash,
output: matchingModelOutput.output as Prisma.InputJsonValue,
timeToComplete: matchingModelOutput.timeToComplete,
cost: matchingModelOutput.cost,
promptTokens: matchingModelOutput.promptTokens,
completionTokens: matchingModelOutput.completionTokens,
createdAt: matchingModelOutput.createdAt,
updatedAt: matchingModelOutput.updatedAt,
},
});
await prisma.scenarioVariantCell.update({
where: { id: cell.id },
data: { retrievalStatus: "COMPLETE" },
});
// Copy over all eval results as well
await Promise.all(
(
await prisma.outputEvaluation.findMany({ where: { modelOutputId: matchingModelOutput.id } })
).map(async (evaluation) => {
await prisma.outputEvaluation.create({
data: {
...omit(evaluation, ["id"]),
modelOutputId: newModelOutput.id,
},
});
}),
);
} else {
cell = await queueLLMRetrievalTask(cell.id);
}
return { ...cell, modelOutput: newModelOutput };
};