Trigger llm output retrieval on server (#39)
* Rename tables, add graphile workers, update types * Add dev:worker command * Update pnpm-lock.yaml * Remove sentry config import from worker.ts * Stop generating new cells in cell router get query * Generate new cells for new scenarios, variants, and experiments * Remove most error throwing from queryLLM.task.ts * Remove promptVariantId and testScenarioId from ModelOutput * Remove duplicate index from ModelOutput * Move inputHash from cell to output * Add TODO * Add todo * Show cost and time for each cell * Always show output stats if there is output * Trigger LLM outputs when scenario variables are updated * Add newlines to ends of files * Add another newline * Cascade ModelOutput deletion * Fix linting and prettier * Return instead of throwing for non-pending cell * Remove pnpm dev:worker from pnpm:dev * Update pnpm-lock.yaml
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { type Evaluation } from "@prisma/client";
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { evaluateOutput } from "./evaluateOutput";
|
||||
|
||||
@@ -12,21 +12,22 @@ export const reevaluateVariant = async (variantId: string) => {
|
||||
where: { experimentId: variant.experimentId },
|
||||
});
|
||||
|
||||
const modelOutputs = await prisma.modelOutput.findMany({
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
promptVariantId: variantId,
|
||||
statusCode: { notIn: [429] },
|
||||
retrievalStatus: "COMPLETE",
|
||||
testScenario: { visible: true },
|
||||
modelOutput: { isNot: null },
|
||||
},
|
||||
include: { testScenario: true },
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => {
|
||||
const passCount = modelOutputs.filter((output) =>
|
||||
evaluateOutput(output, output.testScenario, evaluation),
|
||||
const passCount = cells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = modelOutputs.length - passCount;
|
||||
const failCount = cells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
@@ -55,22 +56,23 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
||||
where: { experimentId: evaluation.experimentId, visible: true },
|
||||
});
|
||||
|
||||
const modelOutputs = await prisma.modelOutput.findMany({
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
promptVariantId: { in: variants.map((v) => v.id) },
|
||||
testScenario: { visible: true },
|
||||
statusCode: { notIn: [429] },
|
||||
modelOutput: { isNot: null },
|
||||
},
|
||||
include: { testScenario: true },
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants.map(async (variant) => {
|
||||
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
|
||||
const passCount = outputs.filter((output) =>
|
||||
evaluateOutput(output, output.testScenario, evaluation),
|
||||
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||
const passCount = variantCells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = outputs.length - passCount;
|
||||
const failCount = variantCells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
|
||||
76
src/server/utils/generateNewCell.ts
Normal file
76
src/server/utils/generateNewCell.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
import crypto from "crypto";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { queueLLMRetrievalTask } from "./queueLLMRetrievalTask";
|
||||
import { constructPrompt } from "./constructPrompt";
|
||||
|
||||
export const generateNewCell = async (variantId: string, scenarioId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: {
|
||||
id: variantId,
|
||||
},
|
||||
});
|
||||
|
||||
const scenario = await prisma.testScenario.findUnique({
|
||||
where: {
|
||||
id: scenarioId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!variant || !scenario) return null;
|
||||
|
||||
const prompt = await constructPrompt(variant, scenario.variableValues);
|
||||
|
||||
const inputHash = crypto.createHash("sha256").update(JSON.stringify(prompt)).digest("hex");
|
||||
|
||||
let cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: {
|
||||
promptVariantId_testScenarioId: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (cell) return cell;
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
promptVariantId: variantId,
|
||||
testScenarioId: scenarioId,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
const matchingModelOutput = await prisma.modelOutput.findFirst({
|
||||
where: {
|
||||
inputHash,
|
||||
},
|
||||
});
|
||||
|
||||
let newModelOutput;
|
||||
|
||||
if (matchingModelOutput) {
|
||||
newModelOutput = await prisma.modelOutput.create({
|
||||
data: {
|
||||
scenarioVariantCellId: cell.id,
|
||||
inputHash,
|
||||
output: matchingModelOutput.output as Prisma.InputJsonValue,
|
||||
timeToComplete: matchingModelOutput.timeToComplete,
|
||||
promptTokens: matchingModelOutput.promptTokens,
|
||||
completionTokens: matchingModelOutput.completionTokens,
|
||||
createdAt: matchingModelOutput.createdAt,
|
||||
updatedAt: matchingModelOutput.updatedAt,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
cell = await queueLLMRetrievalTask(cell.id);
|
||||
}
|
||||
|
||||
return { ...cell, modelOutput: newModelOutput };
|
||||
};
|
||||
@@ -9,7 +9,7 @@ import { env } from "~/env.mjs";
|
||||
import { countOpenAIChatTokens } from "~/utils/countTokens";
|
||||
import { rateLimitErrorMessage } from "~/sharedStrings";
|
||||
|
||||
type CompletionResponse = {
|
||||
export type CompletionResponse = {
|
||||
output: Prisma.InputJsonValue | typeof Prisma.JsonNull;
|
||||
statusCode: number;
|
||||
errorMessage: string | null;
|
||||
|
||||
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
22
src/server/utils/queueLLMRetrievalTask.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { prisma } from "../db";
|
||||
import { queryLLM } from "../tasks/queryLLM.task";
|
||||
|
||||
export const queueLLMRetrievalTask = async (cellId: string) => {
|
||||
const updatedCell = await prisma.scenarioVariantCell.update({
|
||||
where: {
|
||||
id: cellId,
|
||||
},
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
},
|
||||
});
|
||||
|
||||
// @ts-expect-error we aren't passing the helpers but that's ok
|
||||
void queryLLM.task.handler({ scenarioVariantCellId: cellId }, { logger: console });
|
||||
|
||||
return updatedCell;
|
||||
};
|
||||
7
src/server/utils/shouldStream.ts
Normal file
7
src/server/utils/shouldStream.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { isObject } from "lodash";
|
||||
import { type JSONSerializable } from "../types";
|
||||
|
||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
||||
const shouldStream = isObject(config) && "stream" in config && config.stream === true;
|
||||
return shouldStream;
|
||||
};
|
||||
1
src/server/utils/sleep.ts
Normal file
1
src/server/utils/sleep.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
Reference in New Issue
Block a user