From 54369dba541b6e4562181f353e642f14b63b94cc Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Mon, 17 Jul 2023 14:14:20 -0700 Subject: [PATCH] Fix seeds and update eval field names --- package.json | 9 +-- pnpm-lock.yaml | 18 +++-- .../migration.sql | 24 +++++++ prisma/schema.prisma | 11 ++-- prisma/seed.ts | 65 ++++++++++++++----- prisma/seedDemo.ts | 15 +---- src/codegen/export-openai-types.ts | 4 +- .../OutputsTable/EditEvaluations.tsx | 26 ++++---- .../OutputsTable/OutputCell/OutputStats.tsx | 2 +- .../OutputsTable/ScenarioEditor.tsx | 2 +- src/components/OutputsTable/VariantStats.tsx | 2 +- src/server/api/autogen.ts | 2 +- src/server/api/routers/evaluations.router.ts | 24 +++---- .../api/routers/promptVariants.router.ts | 2 +- src/server/utils/evaluateOutput.ts | 4 +- src/server/utils/getCompletion.ts | 2 +- src/server/utils/openai.ts | 2 +- src/server/utils/shouldStream.ts | 2 +- 18 files changed, 136 insertions(+), 80 deletions(-) create mode 100644 prisma/migrations/20230717203031_add_gpt4_eval/migration.sql diff --git a/package.json b/package.json index a8c1553..075dcd9 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,8 @@ "postinstall": "prisma generate", "lint": "next lint", "start": "next start", - "codegen": "tsx src/codegen/export-openai-types.ts" + "codegen": "tsx src/codegen/export-openai-types.ts", + "seed": "tsx prisma/seed.ts" }, "dependencies": { "@babel/preset-typescript": "^7.22.5", @@ -49,7 +50,7 @@ "immer": "^10.0.2", "isolated-vm": "^4.5.0", "json-stringify-pretty-compact": "^4.0.0", - "lodash": "^4.17.21", + "lodash-es": "^4.17.21", "next": "^13.4.2", "next-auth": "^4.22.1", "nextjs-routes": "^2.0.1", @@ -77,7 +78,7 @@ "@types/cors": "^2.8.13", "@types/eslint": "^8.37.0", "@types/express": "^4.17.17", - "@types/lodash": "^4.14.195", + "@types/lodash-es": "^4.17.8", "@types/node": "^18.16.0", "@types/pluralize": "^0.0.30", "@types/react": "^18.2.6", @@ -100,6 +101,6 @@ "initVersion": "7.14.0" }, "prisma": { - "seed": "tsx prisma/seed.ts" + "seed": "pnpm seed" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 92d1320..22c5766 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -95,7 +95,7 @@ dependencies: json-stringify-pretty-compact: specifier: ^4.0.0 version: 4.0.0 - lodash: + lodash-es: specifier: ^4.17.21 version: 4.17.21 next: @@ -175,9 +175,9 @@ devDependencies: '@types/express': specifier: ^4.17.17 version: 4.17.17 - '@types/lodash': - specifier: ^4.14.195 - version: 4.14.195 + '@types/lodash-es': + specifier: ^4.17.8 + version: 4.17.8 '@types/node': specifier: ^18.16.0 version: 18.16.0 @@ -2753,6 +2753,12 @@ packages: resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==} dev: true + /@types/lodash-es@4.17.8: + resolution: {integrity: sha512-euY3XQcZmIzSy7YH5+Unb3b2X12Wtk54YWINBvvGQ5SmMvwb11JQskGsfkH/5HXK77Kr8GF0wkVDIxzAisWtog==} + dependencies: + '@types/lodash': 4.14.195 + dev: true + /@types/lodash.mergewith@4.6.7: resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==} dependencies: @@ -5379,6 +5385,10 @@ packages: p-locate: 5.0.0 dev: true + /lodash-es@4.17.21: + resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==} + dev: false + /lodash.merge@4.6.2: resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} dev: true diff --git a/prisma/migrations/20230717203031_add_gpt4_eval/migration.sql b/prisma/migrations/20230717203031_add_gpt4_eval/migration.sql new file mode 100644 index 0000000..0ae7b1b --- /dev/null +++ b/prisma/migrations/20230717203031_add_gpt4_eval/migration.sql @@ -0,0 +1,24 @@ +/* + Warnings: + + - You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break. + - You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break. + - You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break. + - You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break. +*/ + +-- RenameEnum +ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType"; + +-- AlterTable +ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value"; +ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType"; +ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label"; + +-- AlterColumnType +ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType"; + +-- SetNotNullConstraint +ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL; +ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL; +ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 9de878d..bf05a85 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -2,8 +2,7 @@ // learn more about it in the docs: https://pris.ly/d/prisma-schema generator client { - provider = "prisma-client-js" - previewFeatures = ["jsonProtocol"] + provider = "prisma-client-js" } datasource db { @@ -130,7 +129,7 @@ model ModelOutput { @@index([inputHash]) } -enum EvaluationMatchType { +enum EvalType { CONTAINS DOES_NOT_CONTAIN } @@ -138,9 +137,9 @@ enum EvaluationMatchType { model Evaluation { id String @id @default(uuid()) @db.Uuid - name String - matchString String - matchType EvaluationMatchType + label String + evalType EvalType + value String experimentId String @db.Uuid experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) diff --git a/prisma/seed.ts b/prisma/seed.ts index af8f17b..a4fe415 100644 --- a/prisma/seed.ts +++ b/prisma/seed.ts @@ -1,4 +1,6 @@ import { prisma } from "~/server/db"; +import dedent from "dedent"; +import { generateNewCell } from "~/server/utils/generateNewCell"; const experimentId = "11111111-1111-1111-1111-111111111111"; @@ -9,7 +11,7 @@ await prisma.experiment.deleteMany({ }, }); -const experiment = await prisma.experiment.create({ +await prisma.experiment.create({ data: { id: experimentId, label: "Country Capitals Example", @@ -37,28 +39,34 @@ await prisma.promptVariant.createMany({ label: "Prompt Variant 1", sortIndex: 0, model: "gpt-3.5-turbo-0613", - constructFn: `prompt = { - model: "gpt-3.5-turbo-0613", - messages: [{ role: "user", content: "What is the capital of {{country}}?" }], - temperature: 0, - }`, + constructFn: dedent` + prompt = { + model: "gpt-3.5-turbo-0613", + messages: [ + { + role: "user", + content: \`What is the capital of ${"$"}{scenario.country}?\` + } + ], + temperature: 0, + }`, }, { experimentId, label: "Prompt Variant 2", sortIndex: 1, model: "gpt-3.5-turbo-0613", - constructFn: `prompt = { - model: "gpt-3.5-turbo-0613", - messages: [ - { - role: "user", - content: - "What is the capital of {{country}}? Return just the city name and nothing else.", - }, - ], - temperature: 0, - }`, + constructFn: dedent` + prompt = { + model: "gpt-3.5-turbo-0613", + messages: [ + { + role: "user", + content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\` + } + ], + temperature: 0, + }`, }, ], }); @@ -109,3 +117,26 @@ await prisma.testScenario.createMany({ }, ], }); + +const variants = await prisma.promptVariant.findMany({ + where: { + experimentId, + }, +}); + +const scenarios = await prisma.testScenario.findMany({ + where: { + experimentId, + }, +}); + +await Promise.all( + variants + .flatMap((variant) => + scenarios.map((scenario) => ({ + promptVariantId: variant.id, + testScenarioId: scenario.id, + })), + ) + .map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)), +); diff --git a/prisma/seedDemo.ts b/prisma/seedDemo.ts index d86bb78..0fff442 100644 --- a/prisma/seedDemo.ts +++ b/prisma/seedDemo.ts @@ -183,9 +183,9 @@ await prisma.templateVariable.createMany({ await prisma.evaluation.create({ data: { experimentId: redditExperiment.id, - name: "Relevance Accuracy", - matchType: "CONTAINS", - matchString: '"{{relevance}}"', + label: "Relevance Accuracy", + evalType: "CONTAINS", + value: '"{{relevance}}"', }, }); @@ -1124,12 +1124,3 @@ await prisma.testScenario.createMany({ variableValues: vars, })), }); - -// await prisma.evaluation.create({ -// data: { -// experimentId: redditExperiment.id, -// name: "Scores Match", -// matchType: "CONTAINS", -// matchString: "{{score}}", -// }, -// }); diff --git a/src/codegen/export-openai-types.ts b/src/codegen/export-openai-types.ts index f17e2ed..0b3325a 100644 --- a/src/codegen/export-openai-types.ts +++ b/src/codegen/export-openai-types.ts @@ -2,7 +2,7 @@ import fs from "fs"; import path from "path"; import openapiTS, { type OpenAPI3 } from "openapi-typescript"; import YAML from "yaml"; -import _ from "lodash"; +import { pick } from "lodash-es"; import assert from "assert"; const OPENAPI_URL = @@ -31,7 +31,7 @@ modelProperty.oneOf = undefined; delete schema["paths"]; assert(schema.components?.schemas); -schema.components.schemas = _.pick(schema.components?.schemas, [ +schema.components.schemas = pick(schema.components?.schemas, [ "CreateChatCompletionRequest", "ChatCompletionRequestMessage", "ChatCompletionFunctions", diff --git a/src/components/OutputsTable/EditEvaluations.tsx b/src/components/OutputsTable/EditEvaluations.tsx index 6abd79e..7d1a44c 100644 --- a/src/components/OutputsTable/EditEvaluations.tsx +++ b/src/components/OutputsTable/EditEvaluations.tsx @@ -12,13 +12,13 @@ import { Select, FormHelperText, } from "@chakra-ui/react"; -import { type Evaluation, EvaluationMatchType } from "@prisma/client"; +import { type Evaluation, EvalType } from "@prisma/client"; import { useCallback, useState } from "react"; import { BsPencil, BsX } from "react-icons/bs"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -type EvalValues = Pick; +type EvalValues = Pick; export function EvaluationEditor(props: { evaluation: Evaluation | null; @@ -27,9 +27,9 @@ export function EvaluationEditor(props: { onCancel: () => void; }) { const [values, setValues] = useState({ - name: props.evaluation?.name ?? props.defaultName ?? "", - matchString: props.evaluation?.matchString ?? "", - matchType: props.evaluation?.matchType ?? "CONTAINS", + label: props.evaluation?.label ?? props.defaultName ?? "", + value: props.evaluation?.value ?? "", + evalType: props.evaluation?.evalType ?? "CONTAINS", }); return ( @@ -39,7 +39,7 @@ export function EvaluationEditor(props: { Evaluation Name setValues((values) => ({ ...values, name: e.target.value }))} /> @@ -47,15 +47,15 @@ export function EvaluationEditor(props: { Match Type setValues((values) => ({ ...values, matchString: e.target.value }))} + value={values.value} + onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))} /> This string will be interpreted as a regex and checked against each model output. @@ -156,9 +156,9 @@ export default function EditEvaluations() { align="center" key={evaluation.id} > - {evaluation.name} + {evaluation.label} - {evaluation.matchType}: "{evaluation.matchString}" + {evaluation.evalType}: "{evaluation.value}"