Fix seeds and update eval field names

This commit is contained in:
Kyle Corbitt
2023-07-17 14:14:20 -07:00
parent 6b84a59372
commit 54369dba54
18 changed files with 136 additions and 80 deletions

View File

@@ -16,7 +16,8 @@
"postinstall": "prisma generate", "postinstall": "prisma generate",
"lint": "next lint", "lint": "next lint",
"start": "next start", "start": "next start",
"codegen": "tsx src/codegen/export-openai-types.ts" "codegen": "tsx src/codegen/export-openai-types.ts",
"seed": "tsx prisma/seed.ts"
}, },
"dependencies": { "dependencies": {
"@babel/preset-typescript": "^7.22.5", "@babel/preset-typescript": "^7.22.5",
@@ -49,7 +50,7 @@
"immer": "^10.0.2", "immer": "^10.0.2",
"isolated-vm": "^4.5.0", "isolated-vm": "^4.5.0",
"json-stringify-pretty-compact": "^4.0.0", "json-stringify-pretty-compact": "^4.0.0",
"lodash": "^4.17.21", "lodash-es": "^4.17.21",
"next": "^13.4.2", "next": "^13.4.2",
"next-auth": "^4.22.1", "next-auth": "^4.22.1",
"nextjs-routes": "^2.0.1", "nextjs-routes": "^2.0.1",
@@ -77,7 +78,7 @@
"@types/cors": "^2.8.13", "@types/cors": "^2.8.13",
"@types/eslint": "^8.37.0", "@types/eslint": "^8.37.0",
"@types/express": "^4.17.17", "@types/express": "^4.17.17",
"@types/lodash": "^4.14.195", "@types/lodash-es": "^4.17.8",
"@types/node": "^18.16.0", "@types/node": "^18.16.0",
"@types/pluralize": "^0.0.30", "@types/pluralize": "^0.0.30",
"@types/react": "^18.2.6", "@types/react": "^18.2.6",
@@ -100,6 +101,6 @@
"initVersion": "7.14.0" "initVersion": "7.14.0"
}, },
"prisma": { "prisma": {
"seed": "tsx prisma/seed.ts" "seed": "pnpm seed"
} }
} }

18
pnpm-lock.yaml generated
View File

@@ -95,7 +95,7 @@ dependencies:
json-stringify-pretty-compact: json-stringify-pretty-compact:
specifier: ^4.0.0 specifier: ^4.0.0
version: 4.0.0 version: 4.0.0
lodash: lodash-es:
specifier: ^4.17.21 specifier: ^4.17.21
version: 4.17.21 version: 4.17.21
next: next:
@@ -175,9 +175,9 @@ devDependencies:
'@types/express': '@types/express':
specifier: ^4.17.17 specifier: ^4.17.17
version: 4.17.17 version: 4.17.17
'@types/lodash': '@types/lodash-es':
specifier: ^4.14.195 specifier: ^4.17.8
version: 4.14.195 version: 4.17.8
'@types/node': '@types/node':
specifier: ^18.16.0 specifier: ^18.16.0
version: 18.16.0 version: 18.16.0
@@ -2753,6 +2753,12 @@ packages:
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==} resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
dev: true dev: true
/@types/lodash-es@4.17.8:
resolution: {integrity: sha512-euY3XQcZmIzSy7YH5+Unb3b2X12Wtk54YWINBvvGQ5SmMvwb11JQskGsfkH/5HXK77Kr8GF0wkVDIxzAisWtog==}
dependencies:
'@types/lodash': 4.14.195
dev: true
/@types/lodash.mergewith@4.6.7: /@types/lodash.mergewith@4.6.7:
resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==} resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==}
dependencies: dependencies:
@@ -5379,6 +5385,10 @@ packages:
p-locate: 5.0.0 p-locate: 5.0.0
dev: true dev: true
/lodash-es@4.17.21:
resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
dev: false
/lodash.merge@4.6.2: /lodash.merge@4.6.2:
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
dev: true dev: true

View File

@@ -0,0 +1,24 @@
/*
Warnings:
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
*/
-- RenameEnum
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
-- AlterTable
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
-- AlterColumnType
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
-- SetNotNullConstraint
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;

View File

@@ -3,7 +3,6 @@
generator client { generator client {
provider = "prisma-client-js" provider = "prisma-client-js"
previewFeatures = ["jsonProtocol"]
} }
datasource db { datasource db {
@@ -130,7 +129,7 @@ model ModelOutput {
@@index([inputHash]) @@index([inputHash])
} }
enum EvaluationMatchType { enum EvalType {
CONTAINS CONTAINS
DOES_NOT_CONTAIN DOES_NOT_CONTAIN
} }
@@ -138,9 +137,9 @@ enum EvaluationMatchType {
model Evaluation { model Evaluation {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
name String label String
matchString String evalType EvalType
matchType EvaluationMatchType value String
experimentId String @db.Uuid experimentId String @db.Uuid
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade) experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)

View File

@@ -1,4 +1,6 @@
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import dedent from "dedent";
import { generateNewCell } from "~/server/utils/generateNewCell";
const experimentId = "11111111-1111-1111-1111-111111111111"; const experimentId = "11111111-1111-1111-1111-111111111111";
@@ -9,7 +11,7 @@ await prisma.experiment.deleteMany({
}, },
}); });
const experiment = await prisma.experiment.create({ await prisma.experiment.create({
data: { data: {
id: experimentId, id: experimentId,
label: "Country Capitals Example", label: "Country Capitals Example",
@@ -37,9 +39,15 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }], messages: [
{
role: "user",
content: \`What is the capital of ${"$"}{scenario.country}?\`
}
],
temperature: 0, temperature: 0,
}`, }`,
}, },
@@ -48,14 +56,14 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 2", label: "Prompt Variant 2",
sortIndex: 1, sortIndex: 1,
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
constructFn: `prompt = { constructFn: dedent`
prompt = {
model: "gpt-3.5-turbo-0613", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
{ {
role: "user", role: "user",
content: content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
"What is the capital of {{country}}? Return just the city name and nothing else.", }
},
], ],
temperature: 0, temperature: 0,
}`, }`,
@@ -109,3 +117,26 @@ await prisma.testScenario.createMany({
}, },
], ],
}); });
const variants = await prisma.promptVariant.findMany({
where: {
experimentId,
},
});
const scenarios = await prisma.testScenario.findMany({
where: {
experimentId,
},
});
await Promise.all(
variants
.flatMap((variant) =>
scenarios.map((scenario) => ({
promptVariantId: variant.id,
testScenarioId: scenario.id,
})),
)
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
);

View File

@@ -183,9 +183,9 @@ await prisma.templateVariable.createMany({
await prisma.evaluation.create({ await prisma.evaluation.create({
data: { data: {
experimentId: redditExperiment.id, experimentId: redditExperiment.id,
name: "Relevance Accuracy", label: "Relevance Accuracy",
matchType: "CONTAINS", evalType: "CONTAINS",
matchString: '"{{relevance}}"', value: '"{{relevance}}"',
}, },
}); });
@@ -1124,12 +1124,3 @@ await prisma.testScenario.createMany({
variableValues: vars, variableValues: vars,
})), })),
}); });
// await prisma.evaluation.create({
// data: {
// experimentId: redditExperiment.id,
// name: "Scores Match",
// matchType: "CONTAINS",
// matchString: "{{score}}",
// },
// });

View File

@@ -2,7 +2,7 @@ import fs from "fs";
import path from "path"; import path from "path";
import openapiTS, { type OpenAPI3 } from "openapi-typescript"; import openapiTS, { type OpenAPI3 } from "openapi-typescript";
import YAML from "yaml"; import YAML from "yaml";
import _ from "lodash"; import { pick } from "lodash-es";
import assert from "assert"; import assert from "assert";
const OPENAPI_URL = const OPENAPI_URL =
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
delete schema["paths"]; delete schema["paths"];
assert(schema.components?.schemas); assert(schema.components?.schemas);
schema.components.schemas = _.pick(schema.components?.schemas, [ schema.components.schemas = pick(schema.components?.schemas, [
"CreateChatCompletionRequest", "CreateChatCompletionRequest",
"ChatCompletionRequestMessage", "ChatCompletionRequestMessage",
"ChatCompletionFunctions", "ChatCompletionFunctions",

View File

@@ -12,13 +12,13 @@ import {
Select, Select,
FormHelperText, FormHelperText,
} from "@chakra-ui/react"; } from "@chakra-ui/react";
import { type Evaluation, EvaluationMatchType } from "@prisma/client"; import { type Evaluation, EvalType } from "@prisma/client";
import { useCallback, useState } from "react"; import { useCallback, useState } from "react";
import { BsPencil, BsX } from "react-icons/bs"; import { BsPencil, BsX } from "react-icons/bs";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">; type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
export function EvaluationEditor(props: { export function EvaluationEditor(props: {
evaluation: Evaluation | null; evaluation: Evaluation | null;
@@ -27,9 +27,9 @@ export function EvaluationEditor(props: {
onCancel: () => void; onCancel: () => void;
}) { }) {
const [values, setValues] = useState<EvalValues>({ const [values, setValues] = useState<EvalValues>({
name: props.evaluation?.name ?? props.defaultName ?? "", label: props.evaluation?.label ?? props.defaultName ?? "",
matchString: props.evaluation?.matchString ?? "", value: props.evaluation?.value ?? "",
matchType: props.evaluation?.matchType ?? "CONTAINS", evalType: props.evaluation?.evalType ?? "CONTAINS",
}); });
return ( return (
@@ -39,7 +39,7 @@ export function EvaluationEditor(props: {
<FormLabel fontSize="sm">Evaluation Name</FormLabel> <FormLabel fontSize="sm">Evaluation Name</FormLabel>
<Input <Input
size="sm" size="sm"
value={values.name} value={values.label}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
/> />
</FormControl> </FormControl>
@@ -47,15 +47,15 @@ export function EvaluationEditor(props: {
<FormLabel fontSize="sm">Match Type</FormLabel> <FormLabel fontSize="sm">Match Type</FormLabel>
<Select <Select
size="sm" size="sm"
value={values.matchType} value={values.evalType}
onChange={(e) => onChange={(e) =>
setValues((values) => ({ setValues((values) => ({
...values, ...values,
matchType: e.target.value as EvaluationMatchType, evalType: e.target.value as EvalType,
})) }))
} }
> >
{Object.values(EvaluationMatchType).map((type) => ( {Object.values(EvalType).map((type) => (
<option key={type} value={type}> <option key={type} value={type}>
{type} {type}
</option> </option>
@@ -67,8 +67,8 @@ export function EvaluationEditor(props: {
<FormLabel fontSize="sm">Match String</FormLabel> <FormLabel fontSize="sm">Match String</FormLabel>
<Input <Input
size="sm" size="sm"
value={values.matchString} value={values.value}
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
/> />
<FormHelperText> <FormHelperText>
This string will be interpreted as a regex and checked against each model output. This string will be interpreted as a regex and checked against each model output.
@@ -156,9 +156,9 @@ export default function EditEvaluations() {
align="center" align="center"
key={evaluation.id} key={evaluation.id}
> >
<Text fontWeight="bold">{evaluation.name}</Text> <Text fontWeight="bold">{evaluation.label}</Text>
<Text flex={1}> <Text flex={1}>
{evaluation.matchType}: &quot;{evaluation.matchString}&quot; {evaluation.evalType}: &quot;{evaluation.value}&quot;
</Text> </Text>
<Button <Button
variant="unstyled" variant="unstyled"

View File

@@ -42,7 +42,7 @@ export const OutputStats = ({
const passed = evaluateOutput(modelOutput, scenario, evaluation); const passed = evaluateOutput(modelOutput, scenario, evaluation);
return ( return (
<HStack spacing={0} key={evaluation.id}> <HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.name}</Text> <Text>{evaluation.label}</Text>
<Icon <Icon
as={passed ? BsCheck : BsX} as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"} color={passed ? "green.500" : "red.500"}

View File

@@ -1,6 +1,6 @@
import { type DragEvent } from "react"; import { type DragEvent } from "react";
import { api } from "~/utils/api"; import { api } from "~/utils/api";
import { isEqual } from "lodash"; import { isEqual } from "lodash-es";
import { type Scenario } from "./types"; import { type Scenario } from "./types";
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
import { useState } from "react"; import { useState } from "react";

View File

@@ -47,7 +47,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const passedFrac = result.passCount / (result.passCount + result.failCount); const passedFrac = result.passCount / (result.passCount + result.failCount);
return ( return (
<HStack key={result.id}> <HStack key={result.id}>
<Text>{result.evaluation.name}</Text> <Text>{result.evaluation.label}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold"> <Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}% {(passedFrac * 100).toFixed(1)}%
</Text> </Text>

View File

@@ -1,7 +1,7 @@
import { type CompletionCreateParams } from "openai/resources/chat"; import { type CompletionCreateParams } from "openai/resources/chat";
import { prisma } from "../db"; import { prisma } from "../db";
import { openai } from "../utils/openai"; import { openai } from "../utils/openai";
import { pick } from "lodash"; import { pick } from "lodash-es";
type AxiosError = { type AxiosError = {
response?: { response?: {

View File

@@ -1,4 +1,4 @@
import { EvaluationMatchType } from "@prisma/client"; import { EvalType } from "@prisma/client";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
@@ -18,18 +18,18 @@ export const evaluationsRouter = createTRPCRouter({
.input( .input(
z.object({ z.object({
experimentId: z.string(), experimentId: z.string(),
name: z.string(), label: z.string(),
matchString: z.string(), value: z.string(),
matchType: z.nativeEnum(EvaluationMatchType), evalType: z.nativeEnum(EvalType),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({ const evaluation = await prisma.evaluation.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
name: input.name, label: input.label,
matchString: input.matchString, value: input.value,
matchType: input.matchType, evalType: input.evalType,
}, },
}); });
await reevaluateEvaluation(evaluation); await reevaluateEvaluation(evaluation);
@@ -41,8 +41,8 @@ export const evaluationsRouter = createTRPCRouter({
id: z.string(), id: z.string(),
updates: z.object({ updates: z.object({
name: z.string().optional(), name: z.string().optional(),
matchString: z.string().optional(), value: z.string().optional(),
matchType: z.nativeEnum(EvaluationMatchType).optional(), evalType: z.nativeEnum(EvalType).optional(),
}), }),
}), }),
) )
@@ -50,9 +50,9 @@ export const evaluationsRouter = createTRPCRouter({
await prisma.evaluation.update({ await prisma.evaluation.update({
where: { id: input.id }, where: { id: input.id },
data: { data: {
name: input.updates.name, label: input.updates.name,
matchString: input.updates.matchString, value: input.updates.value,
matchType: input.updates.matchType, evalType: input.updates.evalType,
}, },
}); });
await reevaluateEvaluation( await reevaluateEvaluation(

View File

@@ -1,5 +1,5 @@
import dedent from "dedent"; import dedent from "dedent";
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";

View File

@@ -14,11 +14,11 @@ export const evaluateOutput = (
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call); const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap); const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
let match; let match;
switch (evaluation.matchType) { switch (evaluation.evalType) {
case "CONTAINS": case "CONTAINS":
match = stringifiedMessage.match(matchRegex) !== null; match = stringifiedMessage.match(matchRegex) !== null;
break; break;

View File

@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */ /* eslint-disable @typescript-eslint/no-unsafe-call */
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { Prisma } from "@prisma/client"; import { Prisma } from "@prisma/client";
import { streamChatCompletion } from "./openai"; import { streamChatCompletion } from "./openai";
import { wsConnection } from "~/utils/wsConnection"; import { wsConnection } from "~/utils/wsConnection";

View File

@@ -1,4 +1,4 @@
import { omit } from "lodash"; import { omit } from "lodash-es";
import { env } from "~/env.mjs"; import { env } from "~/env.mjs";
import OpenAI from "openai"; import OpenAI from "openai";

View File

@@ -1,4 +1,4 @@
import { isObject } from "lodash"; import { isObject } from "lodash-es";
import { type JSONSerializable } from "../types"; import { type JSONSerializable } from "../types";
export const shouldStream = (config: JSONSerializable): boolean => { export const shouldStream = (config: JSONSerializable): boolean => {