Prep for more model providers

Adds a `modelProvider` field to `promptVariants`, currently just set to "openai/ChatCompletion" for all variants for now.

Adds a `modelProviders/` directory where we can define and store pluggable model providers. Currently just OpenAI. Not everything is pluggable yet -- notably the code to actually generate completions hasn't been migrated to this setup yet.

Does a lot of work to get the types working. Prompts are now defined with a function `definePrompt(modelProvider, config)` instead of `prompt = config`. Added a script to migrate old prompt definitions.

This is still partial work, but the diff is large enough that I want to get it in. I don't think anything is broken but I haven't tested thoroughly.
This commit is contained in:
Kyle Corbitt
2023-07-20 14:47:39 -07:00
parent 2c8c8d07cf
commit ded6678e97
43 changed files with 1195 additions and 3023 deletions

View File

@@ -0,0 +1,17 @@
-- Add new columns allowing NULL values
ALTER TABLE "PromptVariant"
ADD COLUMN "constructFnVersion" INTEGER,
ADD COLUMN "modelProvider" TEXT;
-- Update existing records to have the default values
UPDATE "PromptVariant"
SET "constructFnVersion" = 1,
"modelProvider" = 'openai/ChatCompletion'
WHERE "constructFnVersion" IS NULL OR "modelProvider" IS NULL;
-- Alter table to set NOT NULL constraint
ALTER TABLE "PromptVariant"
ALTER COLUMN "constructFnVersion" SET NOT NULL,
ALTER COLUMN "modelProvider" SET NOT NULL;
ALTER TABLE "ScenarioVariantCell" ADD COLUMN "prompt" JSONB;

View File

@@ -31,9 +31,11 @@ model Experiment {
model PromptVariant {
id String @id @default(uuid()) @db.Uuid
label String
constructFn String
model String
label String
constructFn String
constructFnVersion Int
model String
modelProvider String
uiId String @default(uuid()) @db.Uuid
visible Boolean @default(true)
@@ -98,6 +100,7 @@ model ScenarioVariantCell {
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
prompt Json?
testScenarioId String @db.Uuid
testScenario TestScenario @relation(fields: [testScenarioId], references: [id], onDelete: Cascade)

View File

@@ -46,8 +46,10 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 1",
sortIndex: 0,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [
{
@@ -56,15 +58,17 @@ await prisma.promptVariant.createMany({
}
],
temperature: 0,
}`,
})`,
},
{
experimentId: defaultId,
label: "Prompt Variant 2",
sortIndex: 1,
model: "gpt-3.5-turbo-0613",
modelProvider: "openai/ChatCompletion",
constructFnVersion: 1,
constructFn: dedent`
prompt = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
messages: [
{
@@ -73,7 +77,7 @@ await prisma.promptVariant.createMany({
}
],
temperature: 0,
}`,
})`,
},
],
});