add evaluations
This commit is contained in:
@@ -29,6 +29,7 @@
|
|||||||
"@trpc/next": "^10.26.0",
|
"@trpc/next": "^10.26.0",
|
||||||
"@trpc/react-query": "^10.26.0",
|
"@trpc/react-query": "^10.26.0",
|
||||||
"@trpc/server": "^10.26.0",
|
"@trpc/server": "^10.26.0",
|
||||||
|
"chroma-js": "^2.4.2",
|
||||||
"concurrently": "^8.2.0",
|
"concurrently": "^8.2.0",
|
||||||
"cors": "^2.8.5",
|
"cors": "^2.8.5",
|
||||||
"dayjs": "^1.11.8",
|
"dayjs": "^1.11.8",
|
||||||
@@ -57,6 +58,7 @@
|
|||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
|
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
|
||||||
|
"@types/chroma-js": "^2.4.0",
|
||||||
"@types/cors": "^2.8.13",
|
"@types/cors": "^2.8.13",
|
||||||
"@types/eslint": "^8.37.0",
|
"@types/eslint": "^8.37.0",
|
||||||
"@types/express": "^4.17.17",
|
"@types/express": "^4.17.17",
|
||||||
|
|||||||
16
pnpm-lock.yaml
generated
16
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
|||||||
lockfileVersion: '6.0'
|
lockfileVersion: '6.1'
|
||||||
|
|
||||||
settings:
|
settings:
|
||||||
autoInstallPeers: true
|
autoInstallPeers: true
|
||||||
@@ -50,6 +50,9 @@ dependencies:
|
|||||||
'@trpc/server':
|
'@trpc/server':
|
||||||
specifier: ^10.26.0
|
specifier: ^10.26.0
|
||||||
version: 10.26.0
|
version: 10.26.0
|
||||||
|
chroma-js:
|
||||||
|
specifier: ^2.4.2
|
||||||
|
version: 2.4.2
|
||||||
concurrently:
|
concurrently:
|
||||||
specifier: ^8.2.0
|
specifier: ^8.2.0
|
||||||
version: 8.2.0
|
version: 8.2.0
|
||||||
@@ -130,6 +133,9 @@ devDependencies:
|
|||||||
'@openapi-contrib/openapi-schema-to-json-schema':
|
'@openapi-contrib/openapi-schema-to-json-schema':
|
||||||
specifier: ^4.0.5
|
specifier: ^4.0.5
|
||||||
version: 4.0.5
|
version: 4.0.5
|
||||||
|
'@types/chroma-js':
|
||||||
|
specifier: ^2.4.0
|
||||||
|
version: 2.4.0
|
||||||
'@types/cors':
|
'@types/cors':
|
||||||
specifier: ^2.8.13
|
specifier: ^2.8.13
|
||||||
version: 2.8.13
|
version: 2.8.13
|
||||||
@@ -2072,6 +2078,10 @@ packages:
|
|||||||
'@types/node': 18.16.0
|
'@types/node': 18.16.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/@types/chroma-js@2.4.0:
|
||||||
|
resolution: {integrity: sha512-JklMxityrwjBTjGY2anH8JaTx3yjRU3/sEHSblLH1ba5lqcSh1LnImXJZO5peJfXyqKYWjHTGy4s5Wz++hARrw==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@types/connect@3.4.35:
|
/@types/connect@3.4.35:
|
||||||
resolution: {integrity: sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ==}
|
resolution: {integrity: sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ==}
|
||||||
dependencies:
|
dependencies:
|
||||||
@@ -2697,6 +2707,10 @@ packages:
|
|||||||
fsevents: 2.3.2
|
fsevents: 2.3.2
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/chroma-js@2.4.2:
|
||||||
|
resolution: {integrity: sha512-U9eDw6+wt7V8z5NncY2jJfZa+hUH8XEj8FQHgFJTrUFnJfXYf4Ml4adI2vXZOjqRDpFWtYVWypDfZwnJ+HIR4A==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/client-only@0.0.1:
|
/client-only@0.0.1:
|
||||||
resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==}
|
resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==}
|
||||||
dev: false
|
dev: false
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
-- CreateEnum
|
||||||
|
CREATE TYPE "EvaluationMatchType" AS ENUM ('CONTAINS', 'DOES_NOT_CONTAIN');
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "Evaluation" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"matchString" TEXT NOT NULL,
|
||||||
|
"matchType" "EvaluationMatchType" NOT NULL,
|
||||||
|
"experimentId" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "Evaluation_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "EvaluationResult" (
|
||||||
|
"id" UUID NOT NULL,
|
||||||
|
"passCount" INTEGER NOT NULL,
|
||||||
|
"failCount" INTEGER NOT NULL,
|
||||||
|
"evaluationId" UUID NOT NULL,
|
||||||
|
"promptVariantId" UUID NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "EvaluationResult_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "EvaluationResult_evaluationId_promptVariantId_key" ON "EvaluationResult"("evaluationId", "promptVariantId");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "Evaluation" ADD CONSTRAINT "Evaluation_experimentId_fkey" FOREIGN KEY ("experimentId") REFERENCES "Experiment"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_evaluationId_fkey" FOREIGN KEY ("evaluationId") REFERENCES "Evaluation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "EvaluationResult" ADD CONSTRAINT "EvaluationResult_promptVariantId_fkey" FOREIGN KEY ("promptVariantId") REFERENCES "PromptVariant"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
@@ -22,6 +22,7 @@ model Experiment {
|
|||||||
TemplateVariable TemplateVariable[]
|
TemplateVariable TemplateVariable[]
|
||||||
PromptVariant PromptVariant[]
|
PromptVariant PromptVariant[]
|
||||||
TestScenario TestScenario[]
|
TestScenario TestScenario[]
|
||||||
|
Evaluation Evaluation[]
|
||||||
}
|
}
|
||||||
|
|
||||||
model PromptVariant {
|
model PromptVariant {
|
||||||
@@ -37,9 +38,10 @@ model PromptVariant {
|
|||||||
experimentId String @db.Uuid
|
experimentId String @db.Uuid
|
||||||
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @updatedAt
|
updatedAt DateTime @updatedAt
|
||||||
ModelOutput ModelOutput[]
|
ModelOutput ModelOutput[]
|
||||||
|
EvaluationResult EvaluationResult[]
|
||||||
|
|
||||||
@@index([uiId])
|
@@index([uiId])
|
||||||
}
|
}
|
||||||
@@ -80,7 +82,7 @@ model ModelOutput {
|
|||||||
output Json
|
output Json
|
||||||
statusCode Int
|
statusCode Int
|
||||||
errorMessage String?
|
errorMessage String?
|
||||||
timeToComplete Int @default(0)
|
timeToComplete Int @default(0)
|
||||||
|
|
||||||
promptTokens Int? // Added promptTokens field
|
promptTokens Int? // Added promptTokens field
|
||||||
completionTokens Int? // Added completionTokens field
|
completionTokens Int? // Added completionTokens field
|
||||||
@@ -98,6 +100,44 @@ model ModelOutput {
|
|||||||
@@index([inputHash])
|
@@index([inputHash])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum EvaluationMatchType {
|
||||||
|
CONTAINS
|
||||||
|
DOES_NOT_CONTAIN
|
||||||
|
}
|
||||||
|
|
||||||
|
model Evaluation {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
name String
|
||||||
|
matchString String
|
||||||
|
matchType EvaluationMatchType
|
||||||
|
|
||||||
|
experimentId String @db.Uuid
|
||||||
|
experiment Experiment @relation(fields: [experimentId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
EvaluationResult EvaluationResult[]
|
||||||
|
}
|
||||||
|
|
||||||
|
model EvaluationResult {
|
||||||
|
id String @id @default(uuid()) @db.Uuid
|
||||||
|
|
||||||
|
passCount Int
|
||||||
|
failCount Int
|
||||||
|
|
||||||
|
evaluationId String @db.Uuid
|
||||||
|
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
promptVariantId String @db.Uuid
|
||||||
|
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
@@unique([evaluationId, promptVariantId])
|
||||||
|
}
|
||||||
|
|
||||||
// Necessary for Next auth
|
// Necessary for Next auth
|
||||||
model Account {
|
model Account {
|
||||||
id String @id @default(cuid())
|
id String @id @default(cuid())
|
||||||
|
|||||||
@@ -1,43 +1,216 @@
|
|||||||
import { Text, Heading, Stack } from "@chakra-ui/react";
|
import {
|
||||||
import { useState } from "react";
|
Text,
|
||||||
|
Button,
|
||||||
|
HStack,
|
||||||
|
Heading,
|
||||||
|
Icon,
|
||||||
|
Input,
|
||||||
|
Stack,
|
||||||
|
VStack,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Select,
|
||||||
|
FormHelperText,
|
||||||
|
} from "@chakra-ui/react";
|
||||||
|
import { type Evaluation, EvaluationMatchType } from "@prisma/client";
|
||||||
|
import { useCallback, useState } from "react";
|
||||||
|
import { BsPencil, BsX } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
import { useStore } from "~/utils/store";
|
|
||||||
|
type EvalValues = Pick<Evaluation, "name" | "matchString" | "matchType">;
|
||||||
|
|
||||||
|
export function EvaluationEditor(props: {
|
||||||
|
evaluation: Evaluation | null;
|
||||||
|
defaultName?: string;
|
||||||
|
onSave: (id: string | undefined, vals: EvalValues) => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
}) {
|
||||||
|
const [values, setValues] = useState<EvalValues>({
|
||||||
|
name: props.evaluation?.name ?? props.defaultName ?? "",
|
||||||
|
matchString: props.evaluation?.matchString ?? "",
|
||||||
|
matchType: props.evaluation?.matchType ?? "CONTAINS",
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack borderTopWidth={1} borderColor="gray.200" py={4}>
|
||||||
|
<HStack w="100%">
|
||||||
|
<FormControl flex={1}>
|
||||||
|
<FormLabel fontSize="sm">Evaluation Name</FormLabel>
|
||||||
|
<Input
|
||||||
|
size="sm"
|
||||||
|
value={values.name}
|
||||||
|
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flex={1}>
|
||||||
|
<FormLabel fontSize="sm">Match Type</FormLabel>
|
||||||
|
<Select
|
||||||
|
size="sm"
|
||||||
|
value={values.matchType}
|
||||||
|
onChange={(e) =>
|
||||||
|
setValues((values) => ({
|
||||||
|
...values,
|
||||||
|
matchType: e.target.value as EvaluationMatchType,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{Object.values(EvaluationMatchType).map((type) => (
|
||||||
|
<option key={type} value={type}>
|
||||||
|
{type}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
</HStack>
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel fontSize="sm">Match String</FormLabel>
|
||||||
|
<FormHelperText>
|
||||||
|
This string will be interpreted as a regex and checked against each model output.
|
||||||
|
</FormHelperText>
|
||||||
|
<Input
|
||||||
|
size="sm"
|
||||||
|
value={values.matchString}
|
||||||
|
onChange={(e) => setValues((values) => ({ ...values, matchString: e.target.value }))}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<HStack alignSelf="flex-end">
|
||||||
|
<Button size="sm" onClick={props.onCancel} colorScheme="gray">
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
onClick={() => props.onSave(props.evaluation?.id, values)}
|
||||||
|
colorScheme="blue"
|
||||||
|
>
|
||||||
|
Save
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export default function EditEvaluations() {
|
export default function EditEvaluations() {
|
||||||
const experiment = useExperiment();
|
const experiment = useExperiment();
|
||||||
const vars =
|
const evaluations =
|
||||||
api.templateVars.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||||
|
|
||||||
const [newVariable, setNewVariable] = useState<string>("");
|
const [editingId, setEditingId] = useState<string | null>(null);
|
||||||
const newVarIsValid = newVariable.length > 0 && !vars.map((v) => v.label).includes(newVariable);
|
|
||||||
|
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
const addVarMutation = api.templateVars.create.useMutation();
|
const createMutation = api.evaluations.create.useMutation();
|
||||||
const [onAddVar] = useHandledAsyncCallback(async () => {
|
const updateMutation = api.evaluations.update.useMutation();
|
||||||
if (!experiment.data?.id) return;
|
|
||||||
if (!newVarIsValid) return;
|
|
||||||
await addVarMutation.mutateAsync({
|
|
||||||
experimentId: experiment.data.id,
|
|
||||||
label: newVariable,
|
|
||||||
});
|
|
||||||
await utils.templateVars.list.invalidate();
|
|
||||||
setNewVariable("");
|
|
||||||
}, [addVarMutation, experiment.data?.id, newVarIsValid, newVariable]);
|
|
||||||
|
|
||||||
const deleteMutation = api.templateVars.delete.useMutation();
|
const deleteMutation = api.evaluations.delete.useMutation();
|
||||||
const [onDeleteVar] = useHandledAsyncCallback(async (id: string) => {
|
const [onDelete] = useHandledAsyncCallback(async (id: string) => {
|
||||||
await deleteMutation.mutateAsync({ id });
|
await deleteMutation.mutateAsync({ id });
|
||||||
await utils.templateVars.list.invalidate();
|
await utils.evaluations.list.invalidate();
|
||||||
|
await utils.evaluations.results.invalidate();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const closeDrawer = useStore((state) => state.closeDrawer);
|
const [onSave] = useHandledAsyncCallback(async (id: string | undefined, vals: EvalValues) => {
|
||||||
|
setEditingId(null);
|
||||||
|
if (!experiment.data?.id) return;
|
||||||
|
|
||||||
|
if (id) {
|
||||||
|
await updateMutation.mutateAsync({
|
||||||
|
id,
|
||||||
|
updates: vals,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
await createMutation.mutateAsync({
|
||||||
|
experimentId: experiment.data.id,
|
||||||
|
...vals,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
await utils.evaluations.list.invalidate();
|
||||||
|
await utils.evaluations.results.invalidate();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const onCancel = useCallback(() => {
|
||||||
|
setEditingId(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Stack>
|
<Stack>
|
||||||
<Heading size="sm">Edit Evaluations</Heading>
|
<Heading size="sm">Evaluations</Heading>
|
||||||
<Stack spacing={2} pt={2}>
|
<Stack spacing={4}>
|
||||||
<Text fontSize="sm"></Text>
|
<Text fontSize="sm">
|
||||||
|
Evaluations allow you to compare prompt performance in an automated way.
|
||||||
|
</Text>
|
||||||
|
<Stack spacing={2}>
|
||||||
|
{evaluations.map((evaluation) =>
|
||||||
|
editingId == evaluation.id ? (
|
||||||
|
<EvaluationEditor
|
||||||
|
evaluation={evaluation}
|
||||||
|
onSave={onSave}
|
||||||
|
onCancel={onCancel}
|
||||||
|
key={evaluation.id}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<HStack
|
||||||
|
fontSize="sm"
|
||||||
|
borderTopWidth={1}
|
||||||
|
borderColor="gray.200"
|
||||||
|
py={4}
|
||||||
|
align="center"
|
||||||
|
key={evaluation.id}
|
||||||
|
>
|
||||||
|
<Text fontWeight="bold">{evaluation.name}</Text>
|
||||||
|
<Text flex={1}>
|
||||||
|
{evaluation.matchType}: "{evaluation.matchString}"
|
||||||
|
</Text>
|
||||||
|
<Button
|
||||||
|
variant="unstyled"
|
||||||
|
color="gray.400"
|
||||||
|
height="unset"
|
||||||
|
width="unset"
|
||||||
|
minW="unset"
|
||||||
|
onClick={() => setEditingId(evaluation.id)}
|
||||||
|
_hover={{
|
||||||
|
color: "gray.800",
|
||||||
|
cursor: "pointer",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Icon as={BsPencil} boxSize={4} />
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="unstyled"
|
||||||
|
color="gray.400"
|
||||||
|
height="unset"
|
||||||
|
width="unset"
|
||||||
|
minW="unset"
|
||||||
|
onClick={() => onDelete(evaluation.id)}
|
||||||
|
_hover={{
|
||||||
|
color: "gray.800",
|
||||||
|
cursor: "pointer",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Icon as={BsX} boxSize={6} />
|
||||||
|
</Button>
|
||||||
|
</HStack>
|
||||||
|
)
|
||||||
|
)}
|
||||||
|
{editingId == null && (
|
||||||
|
<Button
|
||||||
|
onClick={() => setEditingId("new")}
|
||||||
|
alignSelf="flex-start"
|
||||||
|
size="sm"
|
||||||
|
mt={4}
|
||||||
|
colorScheme="blue"
|
||||||
|
>
|
||||||
|
Add Evaluation
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
{editingId == "new" && (
|
||||||
|
<EvaluationEditor
|
||||||
|
evaluation={null}
|
||||||
|
defaultName={`Eval${evaluations.length + 1}`}
|
||||||
|
onSave={onSave}
|
||||||
|
onCancel={onCancel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Stack>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Stack>
|
</Stack>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ export default function EditScenarioVars() {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Stack>
|
<Stack>
|
||||||
<Heading size="sm">Edit Scenario Variables</Heading>
|
<Heading size="sm">Scenario Variables</Heading>
|
||||||
<Stack spacing={2} pt={2}>
|
<Stack spacing={2}>
|
||||||
<Text fontSize="sm">
|
<Text fontSize="sm">
|
||||||
Scenario variables can be used in your prompt variants as well as evaluations. Reference
|
Scenario variables can be used in your prompt variants as well as evaluations. Reference
|
||||||
them using <Code>{"{{curly_braces}}"}</Code>.
|
them using <Code>{"{{curly_braces}}"}</Code>.
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { Button } from "@chakra-ui/react";
|
import { Button, Icon, Spinner } from "@chakra-ui/react";
|
||||||
import { BsPlus } from "react-icons/bs";
|
import { BsPlus } from "react-icons/bs";
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks";
|
||||||
@@ -9,7 +9,7 @@ export default function NewVariantButton() {
|
|||||||
const mutation = api.promptVariants.create.useMutation();
|
const mutation = api.promptVariants.create.useMutation();
|
||||||
const utils = api.useContext();
|
const utils = api.useContext();
|
||||||
|
|
||||||
const [onClick] = useHandledAsyncCallback(async () => {
|
const [onClick, loading] = useHandledAsyncCallback(async () => {
|
||||||
if (!experiment.data) return;
|
if (!experiment.data) return;
|
||||||
await mutation.mutateAsync({
|
await mutation.mutateAsync({
|
||||||
experimentId: experiment.data.id,
|
experimentId: experiment.data.id,
|
||||||
@@ -30,7 +30,7 @@ export default function NewVariantButton() {
|
|||||||
height="unset"
|
height="unset"
|
||||||
minH={headerMinHeight}
|
minH={headerMinHeight}
|
||||||
>
|
>
|
||||||
<BsPlus size={24} />
|
<Icon as={loading ? Spinner : BsPlus} boxSize={6} mr={loading ? 1 : 0} />
|
||||||
Add Variant
|
Add Variant
|
||||||
</Button>
|
</Button>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import { type PromptVariant, type Scenario } from "./types";
|
import { type PromptVariant, type Scenario } from "./types";
|
||||||
import { Spinner, Text, Box, Center, Flex, Icon } from "@chakra-ui/react";
|
import { Spinner, Text, Box, Center, Flex, Icon, HStack } from "@chakra-ui/react";
|
||||||
import { useExperiment } from "~/utils/hooks";
|
import { useExperiment } from "~/utils/hooks";
|
||||||
import SyntaxHighlighter from "react-syntax-highlighter";
|
import SyntaxHighlighter from "react-syntax-highlighter";
|
||||||
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
import { docco } from "react-syntax-highlighter/dist/cjs/styles/hljs";
|
||||||
import stringify from "json-stringify-pretty-compact";
|
import stringify from "json-stringify-pretty-compact";
|
||||||
import { useMemo, type ReactElement } from "react";
|
import { useMemo, type ReactElement } from "react";
|
||||||
import { BsClock } from "react-icons/bs";
|
import { BsCheck, BsClock, BsX } from "react-icons/bs";
|
||||||
import { type ModelOutput } from "@prisma/client";
|
import { type ModelOutput } from "@prisma/client";
|
||||||
import { type ChatCompletion } from "openai/resources/chat";
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
import { generateChannel } from "~/utils/generateChannel";
|
import { generateChannel } from "~/utils/generateChannel";
|
||||||
import { isObject } from "lodash";
|
import { isObject } from "lodash";
|
||||||
import useSocket from "~/utils/useSocket";
|
import useSocket from "~/utils/useSocket";
|
||||||
|
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||||
|
|
||||||
export default function OutputCell({
|
export default function OutputCell({
|
||||||
scenario,
|
scenario,
|
||||||
@@ -109,7 +110,7 @@ export default function OutputCell({
|
|||||||
{ maxLength: 40 }
|
{ maxLength: 40 }
|
||||||
)}
|
)}
|
||||||
</SyntaxHighlighter>
|
</SyntaxHighlighter>
|
||||||
<OutputStats modelOutput={output.data} />
|
<OutputStats modelOutput={output.data} scenario={scenario} />
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -120,18 +121,44 @@ export default function OutputCell({
|
|||||||
return (
|
return (
|
||||||
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
<Flex w="100%" h="100%" direction="column" justifyContent="space-between" whiteSpace="pre-wrap">
|
||||||
{contentToDisplay}
|
{contentToDisplay}
|
||||||
{output.data && <OutputStats modelOutput={output.data} />}
|
{output.data && <OutputStats modelOutput={output.data} scenario={scenario} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const OutputStats = ({ modelOutput }: { modelOutput: ModelOutput }) => {
|
const OutputStats = ({
|
||||||
|
modelOutput,
|
||||||
|
scenario,
|
||||||
|
}: {
|
||||||
|
modelOutput: ModelOutput;
|
||||||
|
scenario: Scenario;
|
||||||
|
}) => {
|
||||||
const timeToComplete = modelOutput.timeToComplete;
|
const timeToComplete = modelOutput.timeToComplete;
|
||||||
|
const experiment = useExperiment();
|
||||||
|
const evals =
|
||||||
|
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex justifyContent="flex-end" alignItems="center" color="gray.500" fontSize="xs" mt={2}>
|
<HStack align="center" color="gray.500" fontSize="xs" mt={2}>
|
||||||
<Icon as={BsClock} mr={0.5} />
|
<HStack flex={1}>
|
||||||
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
|
{evals.map((evaluation) => {
|
||||||
</Flex>
|
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||||
|
return (
|
||||||
|
<HStack spacing={0} key={evaluation.id}>
|
||||||
|
<Text>{evaluation.name}</Text>
|
||||||
|
<Icon
|
||||||
|
as={passed ? BsCheck : BsX}
|
||||||
|
color={passed ? "green.500" : "red.500"}
|
||||||
|
boxSize={6}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</HStack>
|
||||||
|
<HStack>
|
||||||
|
<Icon as={BsClock} mr={0.5} />
|
||||||
|
<Text>{(timeToComplete / 1000).toFixed(2)}s</Text>
|
||||||
|
</HStack>
|
||||||
|
</HStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ import {
|
|||||||
DrawerHeader,
|
DrawerHeader,
|
||||||
DrawerOverlay,
|
DrawerOverlay,
|
||||||
Heading,
|
Heading,
|
||||||
|
Stack,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { useStore } from "~/utils/store";
|
import { useStore } from "~/utils/store";
|
||||||
import EditScenarioVars from "./EditScenarioVars";
|
import EditScenarioVars from "./EditScenarioVars";
|
||||||
|
import EditEvaluations from "./EditEvaluations";
|
||||||
|
|
||||||
export default function SettingsDrawer() {
|
export default function SettingsDrawer() {
|
||||||
const isOpen = useStore((state) => state.drawerOpen);
|
const isOpen = useStore((state) => state.drawerOpen);
|
||||||
@@ -23,8 +25,10 @@ export default function SettingsDrawer() {
|
|||||||
<Heading size="md">Settings</Heading>
|
<Heading size="md">Settings</Heading>
|
||||||
</DrawerHeader>
|
</DrawerHeader>
|
||||||
<DrawerBody>
|
<DrawerBody>
|
||||||
<EditScenarioVars />
|
<Stack spacing={6}>
|
||||||
{/* <EditEvaluations /> */}
|
<EditScenarioVars />
|
||||||
|
<EditEvaluations />
|
||||||
|
</Stack>
|
||||||
</DrawerBody>
|
</DrawerBody>
|
||||||
</DrawerContent>
|
</DrawerContent>
|
||||||
</Drawer>
|
</Drawer>
|
||||||
|
|||||||
38
src/components/OutputsTable/VariantStats.tsx
Normal file
38
src/components/OutputsTable/VariantStats.tsx
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import { HStack, Text, useToken } from "@chakra-ui/react";
|
||||||
|
import { type PromptVariant } from "./types";
|
||||||
|
import { cellPadding } from "../constants";
|
||||||
|
import { api } from "~/utils/api";
|
||||||
|
import chroma from "chroma-js";
|
||||||
|
|
||||||
|
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||||
|
const evalResults =
|
||||||
|
api.evaluations.results.useQuery({
|
||||||
|
variantId: props.variant.id,
|
||||||
|
}).data ?? [];
|
||||||
|
|
||||||
|
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||||
|
"green.500",
|
||||||
|
"gray.500",
|
||||||
|
"red.500",
|
||||||
|
]);
|
||||||
|
|
||||||
|
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
||||||
|
|
||||||
|
if (!(evalResults.length > 0)) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
|
||||||
|
{evalResults.map((result) => {
|
||||||
|
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||||
|
return (
|
||||||
|
<HStack key={result.id}>
|
||||||
|
<Text>{result.evaluation.name}</Text>
|
||||||
|
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||||
|
{(passedFrac * 100).toFixed(1)}%
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,11 +1,4 @@
|
|||||||
import {
|
import { Button, Grid, GridItem, HStack, Heading, type SystemStyleObject } from "@chakra-ui/react";
|
||||||
Button,
|
|
||||||
Grid,
|
|
||||||
GridItem,
|
|
||||||
HStack,
|
|
||||||
Heading,
|
|
||||||
type SystemStyleObject,
|
|
||||||
} from "@chakra-ui/react";
|
|
||||||
import { api } from "~/utils/api";
|
import { api } from "~/utils/api";
|
||||||
import NewScenarioButton from "./NewScenarioButton";
|
import NewScenarioButton from "./NewScenarioButton";
|
||||||
import NewVariantButton from "./NewVariantButton";
|
import NewVariantButton from "./NewVariantButton";
|
||||||
@@ -15,6 +8,7 @@ import VariantHeader from "./VariantHeader";
|
|||||||
import { cellPadding } from "../constants";
|
import { cellPadding } from "../constants";
|
||||||
import { BsPencil } from "react-icons/bs";
|
import { BsPencil } from "react-icons/bs";
|
||||||
import { useStore } from "~/utils/store";
|
import { useStore } from "~/utils/store";
|
||||||
|
import VariantStats from "./VariantStats";
|
||||||
|
|
||||||
const stickyHeaderStyle: SystemStyleObject = {
|
const stickyHeaderStyle: SystemStyleObject = {
|
||||||
position: "sticky",
|
position: "sticky",
|
||||||
@@ -38,6 +32,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
if (!variants.data || !scenarios.data) return null;
|
if (!variants.data || !scenarios.data) return null;
|
||||||
|
|
||||||
const allCols = variants.data.length + 1;
|
const allCols = variants.data.length + 1;
|
||||||
|
const headerRows = 3;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid
|
<Grid
|
||||||
@@ -55,7 +50,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
<GridItem
|
<GridItem
|
||||||
display="flex"
|
display="flex"
|
||||||
alignItems="flex-end"
|
alignItems="flex-end"
|
||||||
rowSpan={2}
|
rowSpan={headerRows}
|
||||||
px={cellPadding.x}
|
px={cellPadding.x}
|
||||||
py={cellPadding.y}
|
py={cellPadding.y}
|
||||||
>
|
>
|
||||||
@@ -82,7 +77,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
<GridItem
|
<GridItem
|
||||||
rowSpan={scenarios.data.length + 2}
|
rowSpan={scenarios.data.length + headerRows}
|
||||||
padding={0}
|
padding={0}
|
||||||
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
// Have to use `style` instead of emotion style props to work around css specificity issues conflicting with the "> *" selector on Grid
|
||||||
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
style={{ borderRightWidth: 0, borderBottomWidth: 0 }}
|
||||||
@@ -92,10 +87,15 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
|
|||||||
</GridItem>
|
</GridItem>
|
||||||
|
|
||||||
{variants.data.map((variant) => (
|
{variants.data.map((variant) => (
|
||||||
<GridItem key={variant.uiId} padding={0}>
|
<GridItem key={variant.uiId}>
|
||||||
<VariantConfigEditor variant={variant} />
|
<VariantConfigEditor variant={variant} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
))}
|
))}
|
||||||
|
{variants.data.map((variant) => (
|
||||||
|
<GridItem key={variant.uiId}>
|
||||||
|
<VariantStats variant={variant} />
|
||||||
|
</GridItem>
|
||||||
|
))}
|
||||||
{scenarios.data.map((scenario) => (
|
{scenarios.data.map((scenario) => (
|
||||||
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
|
<ScenarioRow key={scenario.uiId} scenario={scenario} variants={variants.data} />
|
||||||
))}
|
))}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { experimentsRouter } from "./routers/experiments.router";
|
|||||||
import { scenariosRouter } from "./routers/scenarios.router";
|
import { scenariosRouter } from "./routers/scenarios.router";
|
||||||
import { modelOutputsRouter } from "./routers/modelOutputs.router";
|
import { modelOutputsRouter } from "./routers/modelOutputs.router";
|
||||||
import { templateVarsRouter } from "./routers/templateVariables.router";
|
import { templateVarsRouter } from "./routers/templateVariables.router";
|
||||||
|
import { evaluationsRouter } from "./routers/evaluations.router";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is the primary router for your server.
|
* This is the primary router for your server.
|
||||||
@@ -16,6 +17,7 @@ export const appRouter = createTRPCRouter({
|
|||||||
scenarios: scenariosRouter,
|
scenarios: scenariosRouter,
|
||||||
outputs: modelOutputsRouter,
|
outputs: modelOutputsRouter,
|
||||||
templateVars: templateVarsRouter,
|
templateVars: templateVarsRouter,
|
||||||
|
evaluations: evaluationsRouter,
|
||||||
});
|
});
|
||||||
|
|
||||||
// export type definition of API
|
// export type definition of API
|
||||||
|
|||||||
77
src/server/api/routers/evaluations.router.ts
Normal file
77
src/server/api/routers/evaluations.router.ts
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { EvaluationMatchType } from "@prisma/client";
|
||||||
|
import { z } from "zod";
|
||||||
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
|
import { prisma } from "~/server/db";
|
||||||
|
import { reevaluateEvaluation } from "~/server/utils/evaluations";
|
||||||
|
|
||||||
|
export const evaluationsRouter = createTRPCRouter({
|
||||||
|
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||||
|
return await prisma.evaluation.findMany({
|
||||||
|
where: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
},
|
||||||
|
orderBy: { createdAt: "asc" },
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
results: publicProcedure.input(z.object({ variantId: z.string() })).query(async ({ input }) => {
|
||||||
|
return await prisma.evaluationResult.findMany({
|
||||||
|
where: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
},
|
||||||
|
include: { evaluation: true },
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
|
||||||
|
create: publicProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
experimentId: z.string(),
|
||||||
|
name: z.string(),
|
||||||
|
matchString: z.string(),
|
||||||
|
matchType: z.nativeEnum(EvaluationMatchType),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.mutation(async ({ input }) => {
|
||||||
|
const evaluation = await prisma.evaluation.create({
|
||||||
|
data: {
|
||||||
|
experimentId: input.experimentId,
|
||||||
|
name: input.name,
|
||||||
|
matchString: input.matchString,
|
||||||
|
matchType: input.matchType,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
await reevaluateEvaluation(evaluation);
|
||||||
|
}),
|
||||||
|
|
||||||
|
update: publicProcedure
|
||||||
|
.input(
|
||||||
|
z.object({
|
||||||
|
id: z.string(),
|
||||||
|
updates: z.object({
|
||||||
|
name: z.string().optional(),
|
||||||
|
matchString: z.string().optional(),
|
||||||
|
matchType: z.nativeEnum(EvaluationMatchType).optional(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.mutation(async ({ input }) => {
|
||||||
|
await prisma.evaluation.update({
|
||||||
|
where: { id: input.id },
|
||||||
|
data: {
|
||||||
|
name: input.updates.name,
|
||||||
|
matchString: input.updates.matchString,
|
||||||
|
matchType: input.updates.matchType,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
await reevaluateEvaluation(
|
||||||
|
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } })
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
|
||||||
|
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||||
|
await prisma.evaluation.delete({
|
||||||
|
where: { id: input.id },
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
});
|
||||||
@@ -1,15 +1,18 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||||
import { prisma } from "~/server/db";
|
import { prisma } from "~/server/db";
|
||||||
import fillTemplate, { type VariableMap } from "~/server/utils/fillTemplate";
|
import { fillTemplateJson, type VariableMap } from "~/server/utils/fillTemplate";
|
||||||
import { type JSONSerializable } from "~/server/types";
|
import { type JSONSerializable } from "~/server/types";
|
||||||
import { getCompletion } from "~/server/utils/getCompletion";
|
|
||||||
import crypto from "crypto";
|
import crypto from "crypto";
|
||||||
import type { Prisma } from "@prisma/client";
|
import type { Prisma } from "@prisma/client";
|
||||||
|
import { reevaluateVariant } from "~/server/utils/evaluations";
|
||||||
|
import { getCompletion } from "~/server/utils/getCompletion";
|
||||||
|
|
||||||
export const modelOutputsRouter = createTRPCRouter({
|
export const modelOutputsRouter = createTRPCRouter({
|
||||||
get: publicProcedure
|
get: publicProcedure
|
||||||
.input(z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() }))
|
.input(
|
||||||
|
z.object({ scenarioId: z.string(), variantId: z.string(), channel: z.string().optional() })
|
||||||
|
)
|
||||||
.query(async ({ input }) => {
|
.query(async ({ input }) => {
|
||||||
const existing = await prisma.modelOutput.findUnique({
|
const existing = await prisma.modelOutput.findUnique({
|
||||||
where: {
|
where: {
|
||||||
@@ -36,7 +39,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
|
|
||||||
if (!variant || !scenario) return null;
|
if (!variant || !scenario) return null;
|
||||||
|
|
||||||
const filledTemplate = fillTemplate(
|
const filledTemplate = fillTemplateJson(
|
||||||
variant.config as JSONSerializable,
|
variant.config as JSONSerializable,
|
||||||
scenario.variableValues as VariableMap
|
scenario.variableValues as VariableMap
|
||||||
);
|
);
|
||||||
@@ -73,6 +76,8 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
await reevaluateVariant(input.variantId);
|
||||||
|
|
||||||
return modelOutput;
|
return modelOutput;
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
experimentId: input.experimentId,
|
experimentId: input.experimentId,
|
||||||
visible: true,
|
visible: true,
|
||||||
},
|
},
|
||||||
orderBy: {
|
orderBy: { sortIndex: "asc" },
|
||||||
sortIndex: "asc",
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
|||||||
31
src/server/utils/evaluateOutput.ts
Normal file
31
src/server/utils/evaluateOutput.ts
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/client";
|
||||||
|
import { type ChatCompletion } from "openai/resources/chat";
|
||||||
|
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
||||||
|
|
||||||
|
export const evaluateOutput = (
|
||||||
|
modelOutput: ModelOutput,
|
||||||
|
scenario: TestScenario,
|
||||||
|
evaluation: Evaluation
|
||||||
|
): boolean => {
|
||||||
|
const output = modelOutput.output as unknown as ChatCompletion;
|
||||||
|
const message = output.choices?.[0]?.message;
|
||||||
|
|
||||||
|
if (!message) return false;
|
||||||
|
|
||||||
|
const stringifiedMessage = JSON.stringify(message);
|
||||||
|
|
||||||
|
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
||||||
|
|
||||||
|
let match;
|
||||||
|
|
||||||
|
switch (evaluation.matchType) {
|
||||||
|
case "CONTAINS":
|
||||||
|
match = stringifiedMessage.match(matchRegex) !== null;
|
||||||
|
break;
|
||||||
|
case "DOES_NOT_CONTAIN":
|
||||||
|
match = stringifiedMessage.match(matchRegex) === null;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return match;
|
||||||
|
};
|
||||||
91
src/server/utils/evaluations.ts
Normal file
91
src/server/utils/evaluations.ts
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import { type Evaluation } from "@prisma/client";
|
||||||
|
import { prisma } from "../db";
|
||||||
|
import { evaluateOutput } from "./evaluateOutput";
|
||||||
|
|
||||||
|
export const reevaluateVariant = async (variantId: string) => {
|
||||||
|
const variant = await prisma.promptVariant.findUnique({
|
||||||
|
where: { id: variantId },
|
||||||
|
});
|
||||||
|
if (!variant) return;
|
||||||
|
|
||||||
|
const evaluations = await prisma.evaluation.findMany({
|
||||||
|
where: { experimentId: variant.experimentId },
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelOutputs = await prisma.modelOutput.findMany({
|
||||||
|
where: { promptVariantId: variantId },
|
||||||
|
include: { testScenario: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
const scenarios = await prisma.testScenario.findMany({
|
||||||
|
where: { experimentId: variant.experimentId, visible: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
evaluations.map(async (evaluation) => {
|
||||||
|
const passCount = modelOutputs.filter((output) =>
|
||||||
|
evaluateOutput(output, output.testScenario, evaluation)
|
||||||
|
).length;
|
||||||
|
const failCount = scenarios.length - passCount;
|
||||||
|
|
||||||
|
await prisma.evaluationResult.upsert({
|
||||||
|
where: {
|
||||||
|
evaluationId_promptVariantId: {
|
||||||
|
evaluationId: evaluation.id,
|
||||||
|
promptVariantId: variantId,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
create: {
|
||||||
|
evaluationId: evaluation.id,
|
||||||
|
promptVariantId: variantId,
|
||||||
|
passCount,
|
||||||
|
failCount,
|
||||||
|
},
|
||||||
|
update: {
|
||||||
|
passCount,
|
||||||
|
failCount,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
||||||
|
const variants = await prisma.promptVariant.findMany({
|
||||||
|
where: { experimentId: evaluation.experimentId, visible: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelOutputs = await prisma.modelOutput.findMany({
|
||||||
|
where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } },
|
||||||
|
include: { testScenario: true },
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all(
|
||||||
|
variants.map(async (variant) => {
|
||||||
|
const outputs = modelOutputs.filter((output) => output.promptVariantId === variant.id);
|
||||||
|
const passCount = outputs.filter((output) =>
|
||||||
|
evaluateOutput(output, output.testScenario, evaluation)
|
||||||
|
).length;
|
||||||
|
const failCount = outputs.length - passCount;
|
||||||
|
|
||||||
|
await prisma.evaluationResult.upsert({
|
||||||
|
where: {
|
||||||
|
evaluationId_promptVariantId: {
|
||||||
|
evaluationId: evaluation.id,
|
||||||
|
promptVariantId: variant.id,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
create: {
|
||||||
|
evaluationId: evaluation.id,
|
||||||
|
promptVariantId: variant.id,
|
||||||
|
passCount,
|
||||||
|
failCount,
|
||||||
|
},
|
||||||
|
update: {
|
||||||
|
passCount,
|
||||||
|
failCount,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -2,17 +2,21 @@ import { type JSONSerializable } from "../types";
|
|||||||
|
|
||||||
export type VariableMap = Record<string, string>;
|
export type VariableMap = Record<string, string>;
|
||||||
|
|
||||||
export default function fillTemplate<T extends JSONSerializable>(
|
export function fillTemplate(template: string, variables: VariableMap): string {
|
||||||
|
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function fillTemplateJson<T extends JSONSerializable>(
|
||||||
template: T,
|
template: T,
|
||||||
variables: VariableMap
|
variables: VariableMap
|
||||||
): T {
|
): T {
|
||||||
if (typeof template === "string") {
|
if (typeof template === "string") {
|
||||||
return template.replace(/{{\s*(\w+)\s*}}/g, (_, key: string) => variables[key] || "") as T;
|
return fillTemplate(template, variables) as T;
|
||||||
} else if (Array.isArray(template)) {
|
} else if (Array.isArray(template)) {
|
||||||
return template.map((item) => fillTemplate(item, variables)) as T;
|
return template.map((item) => fillTemplateJson(item, variables)) as T;
|
||||||
} else if (typeof template === "object" && template !== null) {
|
} else if (typeof template === "object" && template !== null) {
|
||||||
return Object.keys(template).reduce((acc, key) => {
|
return Object.keys(template).reduce((acc, key) => {
|
||||||
acc[key] = fillTemplate(template[key] as JSONSerializable, variables);
|
acc[key] = fillTemplateJson(template[key] as JSONSerializable, variables);
|
||||||
return acc;
|
return acc;
|
||||||
}, {} as { [key: string]: JSONSerializable } & T);
|
}, {} as { [key: string]: JSONSerializable } & T);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user