Store multiple ModelResponses (#95)
* Store multiple ModelResponses * Fix prettier * Add CellContent container
This commit is contained in:
@@ -29,17 +29,9 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
const { cellId, stream } = task;
|
||||
const cell = await prisma.scenarioVariantCell.findUnique({
|
||||
where: { id: cellId },
|
||||
include: { modelOutput: true },
|
||||
include: { modelResponses: true },
|
||||
});
|
||||
if (!cell) {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Cell not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -51,6 +43,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "IN_PROGRESS",
|
||||
jobStartedAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
@@ -61,7 +54,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Prompt Variant not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -76,7 +68,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 404,
|
||||
errorMessage: "Scenario not found",
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -90,7 +81,6 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: 400,
|
||||
errorMessage: prompt.error,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
@@ -106,17 +96,24 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
}
|
||||
: null;
|
||||
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
for (let i = 0; true; i++) {
|
||||
const modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
inputHash,
|
||||
scenarioVariantCellId: cellId,
|
||||
requestedAt: new Date(),
|
||||
},
|
||||
});
|
||||
const response = await provider.getCompletion(prompt.modelInput, onStream);
|
||||
if (response.type === "success") {
|
||||
const inputHash = hashPrompt(prompt);
|
||||
|
||||
const modelOutput = await prisma.modelOutput.create({
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
scenarioVariantCellId: cellId,
|
||||
inputHash,
|
||||
output: response.value as Prisma.InputJsonObject,
|
||||
timeToComplete: response.timeToComplete,
|
||||
statusCode: response.statusCode,
|
||||
receivedAt: new Date(),
|
||||
promptTokens: response.promptTokens,
|
||||
completionTokens: response.completionTokens,
|
||||
cost: response.cost,
|
||||
@@ -126,30 +123,35 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
statusCode: response.statusCode,
|
||||
retrievalStatus: "COMPLETE",
|
||||
},
|
||||
});
|
||||
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelResponse);
|
||||
break;
|
||||
} else {
|
||||
const shouldRetry = response.autoRetry && i < MAX_AUTO_RETRIES;
|
||||
const delay = calculateDelay(i);
|
||||
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
await prisma.modelResponse.update({
|
||||
where: { id: modelResponse.id },
|
||||
data: {
|
||||
errorMessage: response.message,
|
||||
statusCode: response.statusCode,
|
||||
errorMessage: response.message,
|
||||
receivedAt: new Date(),
|
||||
retryTime: shouldRetry ? new Date(Date.now() + delay) : null,
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
|
||||
if (shouldRetry) {
|
||||
await sleep(delay);
|
||||
} else {
|
||||
await prisma.scenarioVariantCell.update({
|
||||
where: { id: cellId },
|
||||
data: {
|
||||
retrievalStatus: "ERROR",
|
||||
},
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -165,6 +167,7 @@ export const queueQueryModel = async (cellId: string, stream: boolean) => {
|
||||
data: {
|
||||
retrievalStatus: "PENDING",
|
||||
errorMessage: null,
|
||||
jobQueuedAt: new Date(),
|
||||
},
|
||||
}),
|
||||
queryModel.enqueue({ cellId, stream }),
|
||||
|
||||
Reference in New Issue
Block a user