Store multiple ModelResponses (#95)

* Store multiple ModelResponses

* Fix prettier

* Add CellContent container
This commit is contained in:
arcticfly
2023-07-25 18:54:38 -07:00
committed by GitHub
parent 45afb1f1f4
commit 98b231c8bd
15 changed files with 341 additions and 159 deletions

View File

@@ -29,17 +29,9 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
const { cellId, stream } = task;
const cell = await prisma.scenarioVariantCell.findUnique({
where: { id: cellId },
include: { modelOutput: true },
include: { modelResponses: true },
});
if (!cell) {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Cell not found",
retrievalStatus: "ERROR",
},
});
return;
}
@@ -51,6 +43,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
where: { id: cellId },
data: {
retrievalStatus: "IN_PROGRESS",
jobStartedAt: new Date(),
},
});
@@ -61,7 +54,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Prompt Variant not found",
retrievalStatus: "ERROR",
},
@@ -76,7 +68,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 404,
errorMessage: "Scenario not found",
retrievalStatus: "ERROR",
},
@@ -90,7 +81,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: 400,
errorMessage: prompt.error,
retrievalStatus: "ERROR",
},
@@ -106,17 +96,24 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}
: null;
const inputHash = hashPrompt(prompt);
for (let i = 0; true; i++) {
const modelResponse = await prisma.modelResponse.create({
data: {
inputHash,
scenarioVariantCellId: cellId,
requestedAt: new Date(),
},
});
const response = await provider.getCompletion(prompt.modelInput, onStream);
if (response.type === "success") {
const inputHash = hashPrompt(prompt);
const modelOutput = await prisma.modelOutput.create({
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
scenarioVariantCellId: cellId,
inputHash,
output: response.value as Prisma.InputJsonObject,
timeToComplete: response.timeToComplete,
statusCode: response.statusCode,
receivedAt: new Date(),
promptTokens: response.promptTokens,
completionTokens: response.completionTokens,
cost: response.cost,
@@ -126,30 +123,35 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
statusCode: response.statusCode,
retrievalStatus: "COMPLETE",
},
});
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
break;
} else {
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
const delay = calculateDelay(i);
await prisma.scenarioVariantCell.update({
where: { id: cellId },
await prisma.modelResponse.update({
where: { id: modelResponse.id },
data: {
errorMessage: response.message,
statusCode: response.statusCode,
errorMessage: response.message,
receivedAt: new Date(),
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
retrievalStatus: "ERROR",
},
});
if (shouldRetry) {
await sleep(delay);
} else {
await prisma.scenarioVariantCell.update({
where: { id: cellId },
data: {
retrievalStatus: "ERROR",
},
});
break;
}
}
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
data: {
retrievalStatus: "PENDING",
errorMessage: null,
jobQueuedAt: new Date(),
},
}),
queryModel.enqueue({ cellId, stream }),