remove the default value for PromptVariant.model

We should be explicit about setting the appropriate model so it always matches the constructFn.
This commit is contained in:
Kyle Corbitt
2023-07-14 17:43:52 -07:00
parent 0c3bdbe4f2
commit 3b99b7bd2b
4 changed files with 19 additions and 10 deletions

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "PromptVariant" ALTER COLUMN "model" DROP DEFAULT;

View File

@@ -30,7 +30,7 @@ model PromptVariant {
label String label String
constructFn String constructFn String
model String @default("gpt-3.5-turbo") model String
uiId String @default(uuid()) @db.Uuid uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true) visible Boolean @default(true)
@@ -39,10 +39,10 @@ model PromptVariant {
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[] scenarioVariantCells ScenarioVariantCell[]
EvaluationResult EvaluationResult[] EvaluationResult EvaluationResult[]
@@index([uiId]) @@index([uiId])
} }
@@ -59,8 +59,8 @@ model TestScenario {
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[] scenarioVariantCells ScenarioVariantCell[]
} }
@@ -90,7 +90,7 @@ model ScenarioVariantCell {
output Json? // TODO: Remove once migration is complete output Json? // TODO: Remove once migration is complete
statusCode Int? statusCode Int?
errorMessage String? errorMessage String?
timeToComplete Int? @default(0) // TODO: Remove once migration is complete timeToComplete Int? @default(0) // TODO: Remove once migration is complete
retryTime DateTime? retryTime DateTime?
streamingChannel String? streamingChannel String?
retrievalStatus CellRetrievalStatus @default(COMPLETE) retrievalStatus CellRetrievalStatus @default(COMPLETE)
@@ -116,14 +116,14 @@ model ModelOutput {
inputHash String inputHash String
output Json output Json
timeToComplete Int @default(0) timeToComplete Int @default(0)
promptTokens Int? promptTokens Int?
completionTokens Int? completionTokens Int?
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
scenarioVariantCellId String @db.Uuid scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
@@unique([scenarioVariantCellId]) @@unique([scenarioVariantCellId])

View File

@@ -36,6 +36,7 @@ await prisma.promptVariant.createMany({
experimentId, experimentId,
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }], messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
@@ -46,6 +47,7 @@ await prisma.promptVariant.createMany({
experimentId, experimentId,
label: "Prompt Variant 2", label: "Prompt Variant 2",
sortIndex: 1, sortIndex: 1,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [

View File

@@ -12,6 +12,7 @@ await prisma.promptVariant.createMany({
{ {
experimentId: functionCallsExperiment.id, experimentId: functionCallsExperiment.id,
label: "No Fn Calls", label: "No Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -30,6 +31,7 @@ await prisma.promptVariant.createMany({
{ {
experimentId: functionCallsExperiment.id, experimentId: functionCallsExperiment.id,
label: "Fn Calls", label: "Fn Calls",
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -92,6 +94,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "3.5 Base", label: "3.5 Base",
sortIndex: 0, sortIndex: 0,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
@@ -107,6 +110,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "4 Base", label: "4 Base",
sortIndex: 1, sortIndex: 1,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-4-0613", model: "gpt-4-0613",
messages: [ messages: [
@@ -122,6 +126,7 @@ await prisma.promptVariant.createMany({
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
label: "3.5 CoT + Functions", label: "3.5 CoT + Functions",
sortIndex: 2, sortIndex: 2,
model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: `prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [