Fix seeds and update eval field names
This commit is contained in:
@@ -16,7 +16,8 @@
|
|||||||
"postinstall": "prisma generate",
|
"postinstall": "prisma generate",
|
||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
"codegen": "tsx src/codegen/export-openai-types.ts"
|
"codegen": "tsx src/codegen/export-openai-types.ts",
|
||||||
|
"seed": "tsx prisma/seed.ts"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@babel/preset-typescript": "^7.22.5",
|
"@babel/preset-typescript": "^7.22.5",
|
||||||
@@ -49,7 +50,7 @@
|
|||||||
"immer": "^10.0.2",
|
"immer": "^10.0.2",
|
||||||
"isolated-vm": "^4.5.0",
|
"isolated-vm": "^4.5.0",
|
||||||
"json-stringify-pretty-compact": "^4.0.0",
|
"json-stringify-pretty-compact": "^4.0.0",
|
||||||
"lodash": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"next": "^13.4.2",
|
"next": "^13.4.2",
|
||||||
"next-auth": "^4.22.1",
|
"next-auth": "^4.22.1",
|
||||||
"nextjs-routes": "^2.0.1",
|
"nextjs-routes": "^2.0.1",
|
||||||
@@ -77,7 +78,7 @@
|
|||||||
"@types/cors": "^2.8.13",
|
"@types/cors": "^2.8.13",
|
||||||
"@types/eslint": "^8.37.0",
|
"@types/eslint": "^8.37.0",
|
||||||
"@types/express": "^4.17.17",
|
"@types/express": "^4.17.17",
|
||||||
"@types/lodash": "^4.14.195",
|
"@types/lodash-es": "^4.17.8",
|
||||||
"@types/node": "^18.16.0",
|
"@types/node": "^18.16.0",
|
||||||
"@types/pluralize": "^0.0.30",
|
"@types/pluralize": "^0.0.30",
|
||||||
"@types/react": "^18.2.6",
|
"@types/react": "^18.2.6",
|
||||||
@@ -100,6 +101,6 @@
|
|||||||
"initVersion": "7.14.0"
|
"initVersion": "7.14.0"
|
||||||
},
|
},
|
||||||
"prisma": {
|
"prisma": {
|
||||||
"seed": "tsx prisma/seed.ts"
|
"seed": "pnpm seed"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
18
pnpm-lock.yaml
generated
18
pnpm-lock.yaml
generated
@@ -95,7 +95,7 @@ dependencies:
|
|||||||
json-stringify-pretty-compact:
|
json-stringify-pretty-compact:
|
||||||
specifier: ^4.0.0
|
specifier: ^4.0.0
|
||||||
version: 4.0.0
|
version: 4.0.0
|
||||||
lodash:
|
lodash-es:
|
||||||
specifier: ^4.17.21
|
specifier: ^4.17.21
|
||||||
version: 4.17.21
|
version: 4.17.21
|
||||||
next:
|
next:
|
||||||
@@ -175,9 +175,9 @@ devDependencies:
|
|||||||
'@types/express':
|
'@types/express':
|
||||||
specifier: ^4.17.17
|
specifier: ^4.17.17
|
||||||
version: 4.17.17
|
version: 4.17.17
|
||||||
'@types/lodash':
|
'@types/lodash-es':
|
||||||
specifier: ^4.14.195
|
specifier: ^4.17.8
|
||||||
version: 4.14.195
|
version: 4.17.8
|
||||||
'@types/node':
|
'@types/node':
|
||||||
specifier: ^18.16.0
|
specifier: ^18.16.0
|
||||||
version: 18.16.0
|
version: 18.16.0
|
||||||
@@ -2753,6 +2753,12 @@ packages:
|
|||||||
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
|
resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/@types/lodash-es@4.17.8:
|
||||||
|
resolution: {integrity: sha512-euY3XQcZmIzSy7YH5+Unb3b2X12Wtk54YWINBvvGQ5SmMvwb11JQskGsfkH/5HXK77Kr8GF0wkVDIxzAisWtog==}
|
||||||
|
dependencies:
|
||||||
|
'@types/lodash': 4.14.195
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@types/lodash.mergewith@4.6.7:
|
/@types/lodash.mergewith@4.6.7:
|
||||||
resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==}
|
resolution: {integrity: sha512-3m+lkO5CLRRYU0fhGRp7zbsGi6+BZj0uTVSwvcKU+nSlhjA9/QRNfuSGnD2mX6hQA7ZbmcCkzk5h4ZYGOtk14A==}
|
||||||
dependencies:
|
dependencies:
|
||||||
@@ -5379,6 +5385,10 @@ packages:
|
|||||||
p-locate: 5.0.0
|
p-locate: 5.0.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/lodash-es@4.17.21:
|
||||||
|
resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/lodash.merge@4.6.2:
|
/lodash.merge@4.6.2:
|
||||||
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
|
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
|
||||||
dev: true
|
dev: true
|
||||||
|
|||||||
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
24
prisma/migrations/20230717203031_add_gpt4_eval/migration.sql
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to rename the column `matchString` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||||
|
- You are about to rename the column `matchType` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||||
|
- You are about to rename the column `name` on the `Evaluation` table. If there is any code or views referring to the old name, they will break.
|
||||||
|
- You are about to rename the enum `EvaluationMatchType` to `EvalType`. If there is any code or views referring to the old name, they will break.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- RenameEnum
|
||||||
|
ALTER TYPE "EvaluationMatchType" RENAME TO "EvalType";
|
||||||
|
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "Evaluation" RENAME COLUMN "matchString" TO "value";
|
||||||
|
ALTER TABLE "Evaluation" RENAME COLUMN "matchType" TO "evalType";
|
||||||
|
ALTER TABLE "Evaluation" RENAME COLUMN "name" TO "label";
|
||||||
|
|
||||||
|
-- AlterColumnType
|
||||||
|
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" TYPE "EvalType" USING "evalType"::text::"EvalType";
|
||||||
|
|
||||||
|
-- SetNotNullConstraint
|
||||||
|
ALTER TABLE "Evaluation" ALTER COLUMN "evalType" SET NOT NULL;
|
||||||
|
ALTER TABLE "Evaluation" ALTER COLUMN "label" SET NOT NULL;
|
||||||
|
ALTER TABLE "Evaluation" ALTER COLUMN "value" SET NOT NULL;
|
||||||
@@ -2,8 +2,7 @@
|
|||||||
// learn more about it in the docs: https://pris.ly/d/prisma-schema
|
// learn more about it in the docs: https://pris.ly/d/prisma-schema
|
||||||
|
|
||||||
generator client {
|
generator client {
|
||||||
provider = "prisma-client-js"
|
provider = "prisma-client-js"
|
||||||
previewFeatures = ["jsonProtocol"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
datasource db {
|
datasource db {
|
||||||
@@ -130,7 +129,7 @@ model ModelOutput {
|
|||||||
@@index([inputHash])
|
@@index([inputHash])
|
||||||
}
|
}
|
||||||
|
|
||||||
enum EvaluationMatchType {
|
enum EvalType {
|
||||||
CONTAINS
|
CONTAINS
|
||||||
DOES_NOT_CONTAIN
|
DOES_NOT_CONTAIN
|
||||||
}
|
}
|
||||||
@@ -138,9 +137,9 @@ enum EvaluationMatchType {
|
|||||||
model Evaluation {
|
model Evaluation {
|
||||||
id String @id @default(uuid()) @db.Uuid
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
name String
|
label String
|
||||||
matchString String
|
evalType EvalType
|
||||||
matchType EvaluationMatchType
|
value String
|
||||||
|
|
||||||
experimentId String @db.Uuid
|
experimentId String @db.Uuid
|
||||||
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
import dedent from "dedent";
|
||||||
|
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||||
|
|
||||||
const experimentId = "11111111-1111-1111-1111-111111111111";
|
const experimentId = "11111111-1111-1111-1111-111111111111";
|
||||||
|
|
||||||
@@ -9,7 +11,7 @@ await prisma.experiment.deleteMany({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const experiment = await prisma.experiment.create({
|
await prisma.experiment.create({
|
||||||
data: {
|
data: {
|
||||||
id: experimentId,
|
id: experimentId,
|
||||||
label: "Country Capitals Example",
|
label: "Country Capitals Example",
|
||||||
@@ -37,28 +39,34 @@ await prisma.promptVariant.createMany({
|
|||||||
label: "Prompt Variant 1",
|
label: "Prompt Variant 1",
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
constructFn: `prompt = {
|
constructFn: dedent`
|
||||||
model: "gpt-3.5-turbo-0613",
|
prompt = {
|
||||||
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
|
model: "gpt-3.5-turbo-0613",
|
||||||
temperature: 0,
|
messages: [
|
||||||
}`,
|
{
|
||||||
|
role: "user",
|
||||||
|
content: \`What is the capital of ${"$"}{scenario.country}?\`
|
||||||
|
}
|
||||||
|
],
|
||||||
|
temperature: 0,
|
||||||
|
}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
experimentId,
|
experimentId,
|
||||||
label: "Prompt Variant 2",
|
label: "Prompt Variant 2",
|
||||||
sortIndex: 1,
|
sortIndex: 1,
|
||||||
model: "gpt-3.5-turbo-0613",
|
model: "gpt-3.5-turbo-0613",
|
||||||
constructFn: `prompt = {
|
constructFn: dedent`
|
||||||
model: "gpt-3.5-turbo-0613",
|
prompt = {
|
||||||
messages: [
|
model: "gpt-3.5-turbo-0613",
|
||||||
{
|
messages: [
|
||||||
role: "user",
|
{
|
||||||
content:
|
role: "user",
|
||||||
"What is the capital of {{country}}? Return just the city name and nothing else.",
|
content: \`What is the capital of ${"$"}{scenario.country}? Return just the city name and nothing else.\`
|
||||||
},
|
}
|
||||||
],
|
],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
@@ -109,3 +117,26 @@ await prisma.testScenario.createMany({
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const variants = await prisma.promptVariant.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
variants
|
||||||
|
.flatMap((variant) =>
|
||||||
|
scenarios.map((scenario) => ({
|
||||||
|
promptVariantId: variant.id,
|
||||||
|
testScenarioId: scenario.id,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.map((cell) => generateNewCell(cell.promptVariantId, cell.testScenarioId)),
|
||||||
|
);
|
||||||
|
|||||||
@@ -183,9 +183,9 @@ await prisma.templateVariable.createMany({
|
|||||||
await prisma.evaluation.create({
|
await prisma.evaluation.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: redditExperiment.id,
|
experimentId: redditExperiment.id,
|
||||||
name: "Relevance Accuracy",
|
label: "Relevance Accuracy",
|
||||||
matchType: "CONTAINS",
|
evalType: "CONTAINS",
|
||||||
matchString: '"{{relevance}}"',
|
value: '"{{relevance}}"',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1124,12 +1124,3 @@ await prisma.testScenario.createMany({
|
|||||||
variableValues: vars,
|
variableValues: vars,
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
|
|
||||||
// await prisma.evaluation.create({
|
|
||||||
// data: {
|
|
||||||
// experimentId: redditExperiment.id,
|
|
||||||
// name: "Scores Match",
|
|
||||||
// matchType: "CONTAINS",
|
|
||||||
// matchString: "{{score}}",
|
|
||||||
// },
|
|
||||||
// });
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import fs from "fs";
|
|||||||
import path from "path";
|
import path from "path";
|
||||||
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
|
import openapiTS, { type OpenAPI3 } from "openapi-typescript";
|
||||||
import YAML from "yaml";
|
import YAML from "yaml";
|
||||||
import _ from "lodash";
|
import { pick } from "lodash-es";
|
||||||
import assert from "assert";
|
import assert from "assert";
|
||||||
|
|
||||||
const OPENAPI_URL =
|
const OPENAPI_URL =
|
||||||
@@ -31,7 +31,7 @@ modelProperty.oneOf = undefined;
|
|||||||
|
|
||||||
delete schema["paths"];
|
delete schema["paths"];
|
||||||
assert(schema.components?.schemas);
|
assert(schema.components?.schemas);
|
||||||
schema.components.schemas = _.pick(schema.components?.schemas, [
|
schema.components.schemas = pick(schema.components?.schemas, [
|
||||||
"CreateChatCompletionRequest",
|
"CreateChatCompletionRequest",
|
||||||
"ChatCompletionRequestMessage",
|
"ChatCompletionRequestMessage",
|
||||||
"ChatCompletionFunctions",
|
"ChatCompletionFunctions",
|
||||||
|
|||||||
@@ -12,13 +12,13 @@ import {
|
|||||||
Select,
|
Select,
|
||||||
FormHelperText,
|
FormHelperText,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
|
import { type Evaluation, EvalType } from "@prisma/client";
|
||||||
import { useCallback, useState } from "react";
|
import { useCallback, useState } from "react";
|
||||||
import { BsPencil, BsX } from "react-icons/bs";
|
import { BsPencil, BsX } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
|
|
||||||
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
|
type EvalValues = Pick<Evaluation, "label" | "value" | "evalType">;
|
||||||
|
|
||||||
export function EvaluationEditor(props: {
|
export function EvaluationEditor(props: {
|
||||||
evaluation: Evaluation | null;
|
evaluation: Evaluation | null;
|
||||||
@@ -27,9 +27,9 @@ export function EvaluationEditor(props: {
|
|||||||
onCancel: () => void;
|
onCancel: () => void;
|
||||||
}) {
|
}) {
|
||||||
const [values, setValues] = useState<EvalValues>({
|
const [values, setValues] = useState<EvalValues>({
|
||||||
name: props.evaluation?.name ?? props.defaultName ?? "",
|
label: props.evaluation?.label ?? props.defaultName ?? "",
|
||||||
matchString: props.evaluation?.matchString ?? "",
|
value: props.evaluation?.value ?? "",
|
||||||
matchType: props.evaluation?.matchType ?? "CONTAINS",
|
evalType: props.evaluation?.evalType ?? "CONTAINS",
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -39,7 +39,7 @@ export function EvaluationEditor(props: {
|
|||||||
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
|
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
|
||||||
<Input
|
<Input
|
||||||
size="sm"
|
size="sm"
|
||||||
value={values.name}
|
value={values.label}
|
||||||
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
@@ -47,15 +47,15 @@ export function EvaluationEditor(props: {
|
|||||||
<FormLabel fontSize="sm">Match Type</FormLabel>
|
<FormLabel fontSize="sm">Match Type</FormLabel>
|
||||||
<Select
|
<Select
|
||||||
size="sm"
|
size="sm"
|
||||||
value={values.matchType}
|
value={values.evalType}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
setValues((values) => ({
|
setValues((values) => ({
|
||||||
...values,
|
...values,
|
||||||
matchType: e.target.value as EvaluationMatchType,
|
evalType: e.target.value as EvalType,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
{Object.values(EvaluationMatchType).map((type) => (
|
{Object.values(EvalType).map((type) => (
|
||||||
<option key={type} value={type}>
|
<option key={type} value={type}>
|
||||||
{type}
|
{type}
|
||||||
</option>
|
</option>
|
||||||
@@ -67,8 +67,8 @@ export function EvaluationEditor(props: {
|
|||||||
<FormLabel fontSize="sm">Match String</FormLabel>
|
<FormLabel fontSize="sm">Match String</FormLabel>
|
||||||
<Input
|
<Input
|
||||||
size="sm"
|
size="sm"
|
||||||
value={values.matchString}
|
value={values.value}
|
||||||
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
|
onChange={(e) => setValues((values) => ({ ...values, value: e.target.value }))}
|
||||||
/>
|
/>
|
||||||
<FormHelperText>
|
<FormHelperText>
|
||||||
This string will be interpreted as a regex and checked against each model output.
|
This string will be interpreted as a regex and checked against each model output.
|
||||||
@@ -156,9 +156,9 @@ export default function EditEvaluations() {
|
|||||||
align="center"
|
align="center"
|
||||||
key={evaluation.id}
|
key={evaluation.id}
|
||||||
>
|
>
|
||||||
<Text fontWeight="bold">{evaluation.name}</Text>
|
<Text fontWeight="bold">{evaluation.label}</Text>
|
||||||
<Text flex={1}>
|
<Text flex={1}>
|
||||||
{evaluation.matchType}: "{evaluation.matchString}"
|
{evaluation.evalType}: "{evaluation.value}"
|
||||||
</Text>
|
</Text>
|
||||||
<Button
|
<Button
|
||||||
variant="unstyled"
|
variant="unstyled"
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ export const OutputStats = ({
|
|||||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||||
return (
|
return (
|
||||||
<HStack spacing={0} key={evaluation.id}>
|
<HStack spacing={0} key={evaluation.id}>
|
||||||
<Text>{evaluation.name}</Text>
|
<Text>{evaluation.label}</Text>
|
||||||
<Icon
|
<Icon
|
||||||
as={passed ? BsCheck : BsX}
|
as={passed ? BsCheck : BsX}
|
||||||
color={passed ? "green.500" : "red.500"}
|
color={passed ? "green.500" : "red.500"}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { type DragEvent } from "react";
|
import { type DragEvent } from "react";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { isEqual } from "lodash";
|
import { isEqual } from "lodash-es";
|
||||||
import { type Scenario } from "./types";
|
import { type Scenario } from "./types";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||||
return (
|
return (
|
||||||
<HStack key={result.id}>
|
<HStack key={result.id}>
|
||||||
<Text>{result.evaluation.name}</Text>
|
<Text>{result.evaluation.label}</Text>
|
||||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||||
{(passedFrac * 100).toFixed(1)}%
|
{(passedFrac * 100).toFixed(1)}%
|
||||||
</Text>
|
</Text>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||||
import { prisma } from "../db";
|
import { prisma } from "../db";
|
||||||
import { openai } from "../utils/openai";
|
import { openai } from "../utils/openai";
|
||||||
import { pick } from "lodash";
|
import { pick } from "lodash-es";
|
||||||
|
|
||||||
type AxiosError = {
|
type AxiosError = {
|
||||||
response?: {
|
response?: {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { EvaluationMatchType } from "@prisma/client";
|
import { EvalType } from "@prisma/client";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
@@ -18,18 +18,18 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
.input(
|
.input(
|
||||||
z.object({
|
z.object({
|
||||||
experimentId: z.string(),
|
experimentId: z.string(),
|
||||||
name: z.string(),
|
label: z.string(),
|
||||||
matchString: z.string(),
|
value: z.string(),
|
||||||
matchType: z.nativeEnum(EvaluationMatchType),
|
evalType: z.nativeEnum(EvalType),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.mutation(async ({ input }) => {
|
.mutation(async ({ input }) => {
|
||||||
const evaluation = await prisma.evaluation.create({
|
const evaluation = await prisma.evaluation.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
name: input.name,
|
label: input.label,
|
||||||
matchString: input.matchString,
|
value: input.value,
|
||||||
matchType: input.matchType,
|
evalType: input.evalType,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
await reevaluateEvaluation(evaluation);
|
await reevaluateEvaluation(evaluation);
|
||||||
@@ -41,8 +41,8 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
id: z.string(),
|
id: z.string(),
|
||||||
updates: z.object({
|
updates: z.object({
|
||||||
name: z.string().optional(),
|
name: z.string().optional(),
|
||||||
matchString: z.string().optional(),
|
value: z.string().optional(),
|
||||||
matchType: z.nativeEnum(EvaluationMatchType).optional(),
|
evalType: z.nativeEnum(EvalType).optional(),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
@@ -50,9 +50,9 @@ export const evaluationsRouter = createTRPCRouter({
|
|||||||
await prisma.evaluation.update({
|
await prisma.evaluation.update({
|
||||||
where: { id: input.id },
|
where: { id: input.id },
|
||||||
data: {
|
data: {
|
||||||
name: input.updates.name,
|
label: input.updates.name,
|
||||||
matchString: input.updates.matchString,
|
value: input.updates.value,
|
||||||
matchType: input.updates.matchType,
|
evalType: input.updates.evalType,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
await reevaluateEvaluation(
|
await reevaluateEvaluation(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import dedent from "dedent";
|
import dedent from "dedent";
|
||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash-es";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ export const evaluateOutput = (
|
|||||||
|
|
||||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||||
|
|
||||||
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
|
||||||
|
|
||||||
let match;
|
let match;
|
||||||
|
|
||||||
switch (evaluation.matchType) {
|
switch (evaluation.evalType) {
|
||||||
case "CONTAINS":
|
case "CONTAINS":
|
||||||
match = stringifiedMessage.match(matchRegex) !== null;
|
match = stringifiedMessage.match(matchRegex) !== null;
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
/* eslint-disable @typescript-eslint/no-unsafe-call */
|
||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash-es";
|
||||||
import { Prisma } from "@prisma/client";
|
import { Prisma } from "@prisma/client";
|
||||||
import { streamChatCompletion } from "./openai";
|
import { streamChatCompletion } from "./openai";
|
||||||
import { wsConnection } from "~/utils/wsConnection";
|
import { wsConnection } from "~/utils/wsConnection";
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { omit } from "lodash";
|
import { omit } from "lodash-es";
|
||||||
import { env } from "~/env.mjs";
|
import { env } from "~/env.mjs";
|
||||||
|
|
||||||
import OpenAI from "openai";
|
import OpenAI from "openai";
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash-es";
|
||||||
import { type JSONSerializable } from "../types";
|
import { type JSONSerializable } from "../types";
|
||||||
|
|
||||||
export const shouldStream = (config: JSONSerializable): boolean => {
|
export const shouldStream = (config: JSONSerializable): boolean => {
|
||||||
|
|||||||
Reference in New Issue
Block a user