diff --git a/app/@types/nextjs-routes.d.ts b/app/@types/nextjs-routes.d.ts index 4e6ca91..b920d3b 100644 --- a/app/@types/nextjs-routes.d.ts +++ b/app/@types/nextjs-routes.d.ts @@ -19,6 +19,8 @@ declare module "nextjs-routes" { | DynamicRoute<"/api/v1/[...trpc]", { "trpc": string[] }> | StaticRoute<"/api/v1/openapi"> | StaticRoute<"/dashboard"> + | DynamicRoute<"/datasets/[id]", { "id": string }> + | StaticRoute<"/datasets"> | DynamicRoute<"/experiments/[experimentSlug]", { "experimentSlug": string }> | StaticRoute<"/experiments"> | StaticRoute<"/fine-tunes"> diff --git a/app/prisma/migrations/20230904234505_revamp_dataset_entry/migration.sql b/app/prisma/migrations/20230904234505_revamp_dataset_entry/migration.sql new file mode 100644 index 0000000..7477d39 --- /dev/null +++ b/app/prisma/migrations/20230904234505_revamp_dataset_entry/migration.sql @@ -0,0 +1,27 @@ +/* + Warnings: + + - Added the required column `input` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty. + - Added the required column `inputTokens` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty. + - Added the required column `outputTokens` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty. + - Added the required column `type` to the `DatasetEntry` table without a default value. This is not possible if the table is not empty. + +*/ +-- CreateEnum +CREATE TYPE "DatasetEntryType" AS ENUM ('TRAIN', 'TEST'); + +-- AlterTable +ALTER TABLE "Dataset" ADD COLUMN "trainingRatio" DOUBLE PRECISION NOT NULL DEFAULT 0.8; + +-- AlterTable +ALTER TABLE "DatasetEntry" ADD COLUMN "input" JSONB NOT NULL, +ADD COLUMN "inputTokens" INTEGER NOT NULL, +ADD COLUMN "output" JSONB, +ADD COLUMN "outputTokens" INTEGER NOT NULL, +ADD COLUMN "type" "DatasetEntryType" NOT NULL; + +-- CreateIndex +CREATE INDEX "DatasetEntry_datasetId_createdAt_id_idx" ON "DatasetEntry"("datasetId", "createdAt", "id"); + +-- CreateIndex +CREATE INDEX "DatasetEntry_datasetId_type_idx" ON "DatasetEntry"("datasetId", "type"); diff --git a/app/prisma/schema.prisma b/app/prisma/schema.prisma index 7f05bdf..f7fe09b 100644 --- a/app/prisma/schema.prisma +++ b/app/prisma/schema.prisma @@ -179,9 +179,10 @@ model OutputEvaluation { model Dataset { id String @id @default(uuid()) @db.Uuid - name String - datasetEntries DatasetEntry[] - fineTunes FineTune[] + name String + datasetEntries DatasetEntry[] + fineTunes FineTune[] + trainingRatio Float @default(0.8) projectId String @db.Uuid project Project @relation(fields: [projectId], references: [id], onDelete: Cascade) @@ -190,17 +191,32 @@ model Dataset { updatedAt DateTime @updatedAt } +enum DatasetEntryType { + TRAIN + TEST +} + model DatasetEntry { id String @id @default(uuid()) @db.Uuid loggedCallId String @db.Uuid loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade) + input Json + output Json? + inputTokens Int + outputTokens Int + + type DatasetEntryType + datasetId String @db.Uuid dataset Dataset? @relation(fields: [datasetId], references: [id], onDelete: Cascade) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt + + @@index([datasetId, createdAt, id]) + @@index([datasetId, type]) } model Project { @@ -452,7 +468,7 @@ model FineTune { deploymentFinishedAt DateTime? datasetId String @db.Uuid - dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade) + dataset Dataset @relation(fields: [datasetId], references: [id], onDelete: Cascade) projectId String @db.Uuid project Project @relation(fields: [projectId], references: [id], onDelete: Cascade) diff --git a/app/prisma/seedAgiEval.ts b/app/prisma/seedAgiEval.ts index 548d3c4..2996103 100644 --- a/app/prisma/seedAgiEval.ts +++ b/app/prisma/seedAgiEval.ts @@ -1,5 +1,4 @@ import { prisma } from "~/server/db"; -import { generateNewCell } from "~/server/utils/generateNewCell"; import dedent from "dedent"; import { execSync } from "child_process"; import fs from "fs"; diff --git a/app/prisma/seedDashboard.ts b/app/prisma/seedDashboard.ts index 1926642..a29d6d3 100644 --- a/app/prisma/seedDashboard.ts +++ b/app/prisma/seedDashboard.ts @@ -108,7 +108,7 @@ const MODEL_RESPONSE_TEMPLATES: { inputTokens: 236, outputTokens: 5, finishReason: "stop", - tags: [{ name: "prompt_id", value: "define_func" }], + tags: [{ name: "prompt_id", value: "add_scenario" }], }, { reqPayload: { @@ -311,7 +311,7 @@ const MODEL_RESPONSE_TEMPLATES: { outputTokens: 108, finishReason: "stop", tags: [ - { name: "prompt_id", value: "chatcmpl-7" }, + { name: "prompt_id", value: "define_func" }, { name: "some_other_tag", value: "some_other_value" }, ], }, diff --git a/app/src/components/requestLogs/ActionButton.tsx b/app/src/components/ActionButton.tsx similarity index 97% rename from app/src/components/requestLogs/ActionButton.tsx rename to app/src/components/ActionButton.tsx index 315274e..6822e02 100644 --- a/app/src/components/requestLogs/ActionButton.tsx +++ b/app/src/components/ActionButton.tsx @@ -3,7 +3,7 @@ import { useState } from "react"; import { Button, HStack, type ButtonProps, Icon, Text } from "@chakra-ui/react"; import { type IconType } from "react-icons"; import { useAppStore } from "~/state/store"; -import { BetaModal } from "../BetaModal"; +import { BetaModal } from "./BetaModal"; const ActionButton = ({ icon, diff --git a/app/src/components/requestLogs/FormattedJson.tsx b/app/src/components/FormattedJson.tsx similarity index 100% rename from app/src/components/requestLogs/FormattedJson.tsx rename to app/src/components/FormattedJson.tsx diff --git a/app/src/components/InputDropdown.tsx b/app/src/components/InputDropdown.tsx index eade73b..b7979ad 100644 --- a/app/src/components/InputDropdown.tsx +++ b/app/src/components/InputDropdown.tsx @@ -16,12 +16,16 @@ import { import { FiChevronDown } from "react-icons/fi"; import { BiCheck } from "react-icons/bi"; +import { isEqual } from "lodash-es"; +import React from "react"; type InputDropdownProps = { options: ReadonlyArray; selectedOption: T; onSelect: (option: T) => void; inputGroupProps?: InputGroupProps; + getDisplayLabel?: (option: T) => string; + isDisabled?: boolean; }; const InputDropdown = ({ @@ -29,19 +33,21 @@ const InputDropdown = ({ selectedOption, onSelect, inputGroupProps, + getDisplayLabel = (option) => option as string, + isDisabled, }: InputDropdownProps) => { - const popover = useDisclosure(); + const { onOpen, ...popover } = useDisclosure(); return ( - + {}} cursor="pointer" @@ -52,9 +58,10 @@ const InputDropdown = ({ onFocus={(e) => { e.target.blur(); }} + isDisabled={isDisabled} /> - + @@ -78,8 +85,10 @@ const InputDropdown = ({ fontSize="sm" borderBottomWidth={1} > - {option as string} - {option === selectedOption && } + {getDisplayLabel(option)} + {isEqual(option, selectedOption) && ( + + )} ))} diff --git a/app/src/components/OutputsTable/ScenariosHeader.tsx b/app/src/components/OutputsTable/ScenariosHeader.tsx index 7e18286..e8b8718 100644 --- a/app/src/components/OutputsTable/ScenariosHeader.tsx +++ b/app/src/components/OutputsTable/ScenariosHeader.tsx @@ -19,15 +19,13 @@ import { useScenarios, } from "~/utils/hooks"; import { BsGear, BsPencil, BsPlus, BsStars } from "react-icons/bs"; -import { useAppStore } from "~/state/store"; import { api } from "~/utils/api"; export const ActionButton = (props: ButtonProps) => ( + + + + ); +}; diff --git a/app/src/components/datasets/DatasetConfigurationDrawer/DeleteDatasetDialog.tsx b/app/src/components/datasets/DatasetConfigurationDrawer/DeleteDatasetDialog.tsx new file mode 100644 index 0000000..454bfc7 --- /dev/null +++ b/app/src/components/datasets/DatasetConfigurationDrawer/DeleteDatasetDialog.tsx @@ -0,0 +1,73 @@ +import { useRef } from "react"; +import { + type UseDisclosureReturn, + AlertDialog, + AlertDialogOverlay, + AlertDialogContent, + AlertDialogHeader, + AlertDialogBody, + AlertDialogFooter, + Button, +} from "@chakra-ui/react"; +import { api } from "~/utils/api"; + +import { useHandledAsyncCallback } from "~/utils/hooks"; + +const DeleteDatasetDialog = ({ + datasetId, + onDelete, + disclosure, +}: { + datasetId?: string; + onDelete?: () => void; + disclosure: UseDisclosureReturn; +}) => { + const cancelRef = useRef(null); + + const mutation = api.datasets.delete.useMutation(); + const utils = api.useContext(); + + const [onDeleteConfirm, deletionInProgress] = useHandledAsyncCallback(async () => { + if (!datasetId) return; + await mutation.mutateAsync({ id: datasetId }); + await utils.datasets.list.invalidate(); + onDelete?.(); + + disclosure.onClose(); + }, [mutation, datasetId, disclosure.onClose]); + + console.log("dataset id", datasetId); + + return ( + + + + + Delete Dataset + + + + If you delete this dataset all the associated dataset entries will be deleted as well. + Are you sure? + + + + + + + + + + ); +}; + +export default DeleteDatasetDialog; diff --git a/app/src/components/datasets/DatasetEntriesTable/DatasetEntriesTable.tsx b/app/src/components/datasets/DatasetEntriesTable/DatasetEntriesTable.tsx new file mode 100644 index 0000000..b2f6876 --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/DatasetEntriesTable.tsx @@ -0,0 +1,46 @@ +import { Card, Table, Tbody } from "@chakra-ui/react"; +import { useState } from "react"; +import { useDatasetEntries } from "~/utils/hooks"; +import { TableHeader, TableRow, EmptyTableRow } from "./TableRow"; +import DatasetEntryEditorDrawer from "./DatasetEntryEditorDrawer"; + +export default function DatasetEntriesTable() { + const [expandedDatasetEntryId, setExpandedDatasetEntryId] = useState(null); + const datasetEntries = useDatasetEntries().data?.entries; + + return ( + <> + + + + + {datasetEntries?.length ? ( + datasetEntries?.map((entry) => { + return ( + { + if (entry.id === expandedDatasetEntryId) { + setExpandedDatasetEntryId(null); + } else { + setExpandedDatasetEntryId(entry.id); + } + }} + showOptions + /> + ); + }) + ) : ( + + )} + +
+
+ setExpandedDatasetEntryId(null)} + /> + + ); +} diff --git a/app/src/components/datasets/DatasetEntriesTable/DatasetEntryEditorDrawer.tsx b/app/src/components/datasets/DatasetEntriesTable/DatasetEntryEditorDrawer.tsx new file mode 100644 index 0000000..2d244d1 --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/DatasetEntryEditorDrawer.tsx @@ -0,0 +1,174 @@ +import { useState, useEffect, useMemo } from "react"; +import { + Drawer, + DrawerBody, + DrawerCloseButton, + DrawerContent, + DrawerHeader, + DrawerOverlay, + DrawerFooter, + Heading, + VStack, + HStack, + Button, + Text, + Divider, + Icon, +} from "@chakra-ui/react"; +import { type CreateChatCompletionRequestMessage } from "openai/resources/chat"; +import { BsPlus } from "react-icons/bs"; +import { type DatasetEntryType } from "@prisma/client"; + +import { api } from "~/utils/api"; +import { useDatasetEntry, useHandledAsyncCallback } from "~/utils/hooks"; +import EditableMessage from "./EditableMessage"; +import EntryTypeDropdown from "./EntryTypeDropdown"; + +export default function DatasetDentryEditorDrawer({ + datasetEntryId, + clearDatasetEntryId, +}: { + datasetEntryId: string | null; + clearDatasetEntryId: () => void; +}) { + const utils = api.useContext(); + + const datasetEntry = useDatasetEntry(datasetEntryId).data; + + const savedInputMessages = useMemo( + () => datasetEntry?.input as unknown as CreateChatCompletionRequestMessage[], + [datasetEntry], + ); + const savedOutputMessage = useMemo( + () => datasetEntry?.output as unknown as CreateChatCompletionRequestMessage, + [datasetEntry], + ); + + const [inputMessagesToSave, setInputMessagesToSave] = useState< + CreateChatCompletionRequestMessage[] + >([]); + const [outputMessageToSave, setOutputMessageToSave] = + useState(null); + + useEffect(() => { + if (savedInputMessages) { + setInputMessagesToSave(savedInputMessages); + setOutputMessageToSave(savedOutputMessage); + } + }, [savedInputMessages, savedOutputMessage]); + + const updateMutation = api.datasetEntries.update.useMutation(); + const [onSave, savingInProgress] = useHandledAsyncCallback(async () => { + if (!datasetEntryId || !inputMessagesToSave) return; + await updateMutation.mutateAsync({ + id: datasetEntryId, + updates: { + input: JSON.stringify(inputMessagesToSave), + output: JSON.stringify(outputMessageToSave), + }, + }); + await utils.datasetEntries.list.invalidate(); + await utils.datasetEntries.get.invalidate({ id: datasetEntryId }); + }, [updateMutation, datasetEntryId, inputMessagesToSave, outputMessageToSave, utils]); + + const [onUpdateType] = useHandledAsyncCallback( + async (type: DatasetEntryType) => { + if (!datasetEntryId) return; + await updateMutation.mutateAsync({ + id: datasetEntryId, + updates: { + type, + }, + }); + await utils.datasetEntries.list.invalidate(); + await utils.datasetEntries.get.invalidate({ id: datasetEntryId }); + }, + [updateMutation, datasetEntryId, utils], + ); + + return ( + + + + + + + Dataset Entry + {datasetEntry && ( + + )} + + + + + + + Input + {inputMessagesToSave.map((message, i) => { + return ( + <> + + { + const newInputMessages = [...inputMessagesToSave]; + newInputMessages[i] = message; + setInputMessagesToSave(newInputMessages); + }} + onDelete={() => { + const newInputMessages = [...inputMessagesToSave]; + newInputMessages.splice(i, 1); + setInputMessagesToSave(newInputMessages); + }} + /> + + ); + })} + + + + + Output + + setOutputMessageToSave(message)} + isOutput + /> + + + + + + + + + + + + + ); +} diff --git a/app/src/components/datasets/DatasetEntriesTable/EditableMessage.tsx b/app/src/components/datasets/DatasetEntriesTable/EditableMessage.tsx new file mode 100644 index 0000000..2025984 --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/EditableMessage.tsx @@ -0,0 +1,105 @@ +import { VStack, HStack, Tooltip, IconButton, Icon } from "@chakra-ui/react"; +import { type CreateChatCompletionRequestMessage } from "openai/resources/chat"; +import { BsX } from "react-icons/bs"; + +import AutoResizeTextArea from "~/components/AutoResizeTextArea"; +import InputDropdown from "~/components/InputDropdown"; +import { parseableToFunctionCall } from "~/utils/utils"; +import FunctionCallEditor from "./FunctionCallEditor"; + +const MESSAGE_ROLE_OPTIONS = ["system", "user", "assistant", "function"] as const; +const OUTPUT_OPTIONS = ["plaintext", "func_call"] as const; + +const EditableMessage = ({ + message, + onEdit, + onDelete, + isOutput, +}: { + message: CreateChatCompletionRequestMessage | null; + onEdit: (message: CreateChatCompletionRequestMessage) => void; + onDelete?: () => void; + isOutput?: boolean; +}) => { + const { role = "assistant", content = "", function_call } = message || {}; + + const currentOutputOption: (typeof OUTPUT_OPTIONS)[number] = function_call + ? "func_call" + : "plaintext"; + + return ( + + + + {!isOutput && ( + { + const updatedMessage = { role: option, content }; + if (role === "assistant" && currentOutputOption === "func_call") { + updatedMessage.content = JSON.stringify(function_call, null, 2); + } + onEdit(updatedMessage); + }} + inputGroupProps={{ w: "32", bgColor: "white" }} + /> + )} + {role === "assistant" && ( + { + const updatedMessage: CreateChatCompletionRequestMessage = { + role, + content: null, + function_call: undefined, + }; + if (option === "plaintext") { + updatedMessage.content = JSON.stringify(function_call, null, 2); + } else if (option === "func_call") { + updatedMessage.function_call = + content && parseableToFunctionCall(content) + ? JSON.parse(content) + : { name: "", arguments: "{}" }; + } + onEdit(updatedMessage); + }} + inputGroupProps={{ w: "32", bgColor: "white" }} + /> + )} + + {!isOutput && ( + + + } + onClick={onDelete} + size="xs" + display="flex" + colorScheme="gray" + color="gray.500" + variant="ghost" + /> + + + )} + + {function_call ? ( + onEdit({ role, function_call, content: null })} + /> + ) : ( + onEdit({ role, content: e.target.value })} + bgColor="white" + /> + )} + + ); +}; + +export default EditableMessage; diff --git a/app/src/components/datasets/DatasetEntriesTable/EntryTypeDropdown.tsx b/app/src/components/datasets/DatasetEntriesTable/EntryTypeDropdown.tsx new file mode 100644 index 0000000..210f96a --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/EntryTypeDropdown.tsx @@ -0,0 +1,24 @@ +import { type DatasetEntryType } from "@prisma/client"; + +import InputDropdown from "~/components/InputDropdown"; + +const ENTRY_TYPE_OPTIONS: DatasetEntryType[] = ["TRAIN", "TEST"]; + +const EntryTypeDropdown = ({ + type, + onTypeChange, +}: { + type: DatasetEntryType; + onTypeChange: (type: DatasetEntryType) => void; +}) => { + return ( + + ); +}; + +export default EntryTypeDropdown; diff --git a/app/src/components/datasets/DatasetEntriesTable/FunctionCallEditor.tsx b/app/src/components/datasets/DatasetEntriesTable/FunctionCallEditor.tsx new file mode 100644 index 0000000..c44afb7 --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/FunctionCallEditor.tsx @@ -0,0 +1,125 @@ +import { useRef, useMemo, useEffect } from "react"; +import { VStack, HStack, Text, Input, Box } from "@chakra-ui/react"; +import { type CreateChatCompletionRequestMessage } from "openai/resources/chat"; + +import { useAppStore } from "~/state/store"; +import { type CreatedEditor } from "~/state/sharedVariantEditor.slice"; + +const FunctionCallEditor = ({ + function_call, + onEdit, +}: { + function_call: CreateChatCompletionRequestMessage.FunctionCall; + onEdit: (function_call: CreateChatCompletionRequestMessage.FunctionCall) => void; +}) => { + const monaco = useAppStore.use.sharedArgumentsEditor.monaco(); + const editorRef = useRef(null); + const editorId = useMemo(() => `editor_${Math.random().toString(36).substring(7)}`, []); + + useEffect(() => { + if (monaco) { + const container = document.getElementById(editorId) as HTMLElement; + + const editor = monaco.editor.create(container, { + value: function_call.arguments, + language: "json", + theme: "customTheme", + lineNumbers: "off", + minimap: { enabled: false }, + wrappingIndent: "indent", + wrappingStrategy: "advanced", + wordWrap: "on", + folding: false, + scrollbar: { + alwaysConsumeMouseWheel: false, + verticalScrollbarSize: 0, + }, + wordWrapBreakAfterCharacters: "", + wordWrapBreakBeforeCharacters: "", + quickSuggestions: true, + renderLineHighlight: "none", + fontSize: 14, + scrollBeyondLastLine: false, + }); + + editorRef.current = editor; + + const updateHeight = () => { + const contentHeight = editor.getContentHeight(); + container.style.height = `${contentHeight}px`; + editor.layout(); + }; + + const attemptDocumentFormat = () => { + const action = editor.getAction("editor.action.formatDocument"); + if (action) { + action + .run() + .then(updateHeight) + .catch((error) => { + console.error("Error running formatDocument:", error); + }); + return true; + } + return false; + }; + + editor.onDidBlurEditorText(() => { + attemptDocumentFormat(); + onEdit({ name: function_call.name, arguments: editor.getValue() }); + }); + + // Interval function to check for action availability + const checkForActionInterval = setInterval(() => { + const formatted = attemptDocumentFormat(); + if (formatted) { + clearInterval(checkForActionInterval); // Clear the interval once the action is found and run + } + }, 100); // Check every 100ms + + // Add content change listener + const contentChangeListener = editor.onDidChangeModelContent(updateHeight); + + const resizeObserver = new ResizeObserver(() => { + editor.layout(); + }); + resizeObserver.observe(container); + + return () => { + contentChangeListener.dispose(); + resizeObserver.disconnect(); + editor?.dispose(); + }; + } + }, [monaco, editorId, function_call.name, function_call.arguments, onEdit]); + + return ( + + + + Name: + + onEdit({ name: e.target.value, arguments: function_call.arguments })} + bgColor="white" + /> + + + Arguments + + + + + + ); +}; + +export default FunctionCallEditor; diff --git a/app/src/components/datasets/DatasetEntriesTable/TableRow.tsx b/app/src/components/datasets/DatasetEntriesTable/TableRow.tsx new file mode 100644 index 0000000..ae14dd5 --- /dev/null +++ b/app/src/components/datasets/DatasetEntriesTable/TableRow.tsx @@ -0,0 +1,128 @@ +import { Box, Td, Tr, Thead, Th, Tooltip, HStack, Text, Checkbox } from "@chakra-ui/react"; +import Link from "next/link"; + +import dayjs from "~/utils/dayjs"; +import { type RouterOutputs } from "~/utils/api"; +import { useAppStore } from "~/state/store"; +import { useIsClientRehydrated, useDatasetEntries } from "~/utils/hooks"; +import { useMemo } from "react"; + +type DatasetEntry = RouterOutputs["datasetEntries"]["list"]["entries"][0]; + +export const TableHeader = () => { + const matchingDatasetEntryIds = useDatasetEntries().data?.matchingEntryIds; + const selectedDatasetEntryIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds); + const addSelectedIds = useAppStore((s) => s.selectedDatasetEntries.addSelectedIds); + const clearSelectedIds = useAppStore((s) => s.selectedDatasetEntries.clearSelectedIds); + const allSelected = useMemo(() => { + if (!matchingDatasetEntryIds || !matchingDatasetEntryIds.length) return false; + return matchingDatasetEntryIds.every((id) => selectedDatasetEntryIds.has(id)); + }, [matchingDatasetEntryIds, selectedDatasetEntryIds]); + const isClientRehydrated = useIsClientRehydrated(); + if (!isClientRehydrated) return null; + + return ( + + + + + { + allSelected ? clearSelectedIds() : addSelectedIds(matchingDatasetEntryIds || []); + }} + /> + + ({selectedDatasetEntryIds.size ? `${selectedDatasetEntryIds.size}/` : ""} + {matchingDatasetEntryIds?.length || 0}) + + + + Created At + Input tokens + Output tokens + Type + + + ); +}; + +export const TableRow = ({ + datasetEntry, + onToggle, + showOptions, +}: { + datasetEntry: DatasetEntry; + onToggle: () => void; + showOptions?: boolean; +}) => { + const createdAt = dayjs(datasetEntry.createdAt).format("MMMM D h:mm A"); + const fullTime = dayjs(datasetEntry.createdAt).toString(); + + const isChecked = useAppStore((s) => s.selectedDatasetEntries.selectedIds.has(datasetEntry.id)); + const toggleChecked = useAppStore((s) => s.selectedDatasetEntries.toggleSelectedId); + + const isClientRehydrated = useIsClientRehydrated(); + if (!isClientRehydrated) return null; + + return ( + + {showOptions && ( + + toggleChecked(datasetEntry.id)} /> + + )} + + + + {createdAt} + + + + {datasetEntry.inputTokens} + {datasetEntry.outputTokens} + {datasetEntry.type} + + ); +}; + +export const EmptyTableRow = ({ filtersApplied = true }: { filtersApplied?: boolean }) => { + const visibleColumns = useAppStore((s) => s.columnVisibility.visibleColumns); + const filters = useAppStore((state) => state.logFilters.filters); + const { isLoading } = useDatasetEntries(); + + if (isLoading) return null; + + if (filters.length && filtersApplied) { + return ( + + + + No matching entries found. Try removing some filters. + + + + ); + } + + return ( + + + + This dataset has no entries. Add some logs in the{" "} + + + Request Logs + + {" "} + tab. + + + + ); +}; diff --git a/app/src/components/datasets/DatasetEntryPaginator.tsx b/app/src/components/datasets/DatasetEntryPaginator.tsx new file mode 100644 index 0000000..cbe6919 --- /dev/null +++ b/app/src/components/datasets/DatasetEntryPaginator.tsx @@ -0,0 +1,16 @@ +import { type StackProps } from "@chakra-ui/react"; + +import { useDatasetEntries } from "~/utils/hooks"; +import Paginator from "../Paginator"; + +const DatasetEntryPaginator = (props: StackProps) => { + const { data } = useDatasetEntries(); + + if (!data) return null; + + const { matchingEntryIds } = data; + + return ; +}; + +export default DatasetEntryPaginator; diff --git a/app/src/components/datasets/DatasetHeaderButtons.tsx b/app/src/components/datasets/DatasetHeaderButtons.tsx new file mode 100644 index 0000000..7be2a38 --- /dev/null +++ b/app/src/components/datasets/DatasetHeaderButtons.tsx @@ -0,0 +1,20 @@ +import { Button, HStack, Icon, Text } from "@chakra-ui/react"; +import { useDataset } from "~/utils/hooks"; +import { BsGearFill } from "react-icons/bs"; + +export const DatasetHeaderButtons = ({ openDrawer }: { openDrawer: () => void }) => { + const dataset = useDataset(); + + if (dataset.isLoading) return null; + + return ( + + + + ); +}; diff --git a/app/src/components/datasets/DatasetsTable.tsx b/app/src/components/datasets/DatasetsTable.tsx new file mode 100644 index 0000000..eb2da68 --- /dev/null +++ b/app/src/components/datasets/DatasetsTable.tsx @@ -0,0 +1,52 @@ +import { Card, Table, Thead, Tr, Th, Tbody, Td, VStack, Icon, Text } from "@chakra-ui/react"; +import { FaTable } from "react-icons/fa"; +import Link from "next/link"; + +import dayjs from "~/utils/dayjs"; +import { useDatasets } from "~/utils/hooks"; + +const DatasetsTable = ({}) => { + const { data } = useDatasets(); + + const datasets = data || []; + + return ( + + {datasets.length ? ( + + + + + + + + + + {datasets.map((dataset) => { + return ( + + + + + + ); + })} + +
NameCreated AtSize
+ + {dataset.name} + + {dayjs(dataset.createdAt).format("MMMM D h:mm A")}{dataset._count.datasetEntries}
+ ) : ( + + + + No Datasets Found. Create your first dataset. + + + )} +
+ ); +}; + +export default DatasetsTable; diff --git a/app/src/components/datasets/ExperimentButton.tsx b/app/src/components/datasets/ExperimentButton.tsx new file mode 100644 index 0000000..db1fcba --- /dev/null +++ b/app/src/components/datasets/ExperimentButton.tsx @@ -0,0 +1,21 @@ +import { RiFlaskLine } from "react-icons/ri"; + +import { useAppStore } from "~/state/store"; +import ActionButton from "../ActionButton"; + +const ExperimentButton = () => { + const selectedIds = useAppStore((s) => s.selectedDatasetEntries.selectedIds); + return ( + { + console.log("experimenting with these ids", selectedIds); + }} + label="Experiment" + icon={RiFlaskLine} + isDisabled={selectedIds.size === 0} + requireBeta + /> + ); +}; + +export default ExperimentButton; diff --git a/app/src/components/requestLogs/FineTuneButton.tsx b/app/src/components/datasets/FineTuneButton.tsx similarity index 86% rename from app/src/components/requestLogs/FineTuneButton.tsx rename to app/src/components/datasets/FineTuneButton.tsx index 91dac22..80bcfd1 100644 --- a/app/src/components/requestLogs/FineTuneButton.tsx +++ b/app/src/components/datasets/FineTuneButton.tsx @@ -20,17 +20,19 @@ import { AiTwotoneThunderbolt } from "react-icons/ai"; import humanId from "human-id"; import { useRouter } from "next/router"; -import { useHandledAsyncCallback } from "~/utils/hooks"; +import { useDataset, useDatasetEntries, useHandledAsyncCallback } from "~/utils/hooks"; import { api } from "~/utils/api"; import { useAppStore } from "~/state/store"; -import ActionButton from "./ActionButton"; +import ActionButton from "../ActionButton"; import InputDropdown from "../InputDropdown"; import { FiChevronDown } from "react-icons/fi"; const SUPPORTED_BASE_MODELS = ["llama2-7b", "llama2-13b", "llama2-70b", "gpt-3.5-turbo"]; const FineTuneButton = () => { - const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); + const datasetEntries = useDatasetEntries().data; + + const numEntries = datasetEntries?.matchingEntryIds.length || 0; const disclosure = useDisclosure(); @@ -40,7 +42,7 @@ const FineTuneButton = () => { onClick={disclosure.onOpen} label="Fine Tune" icon={AiTwotoneThunderbolt} - isDisabled={selectedLogIds.size === 0} + isDisabled={numEntries === 0} requireBeta /> @@ -52,8 +54,8 @@ export default FineTuneButton; const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { const selectedProjectId = useAppStore((s) => s.selectedProjectId); - const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); - const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds); + const dataset = useDataset().data; + const datasetEntries = useDatasetEntries().data; const [selectedBaseModel, setSelectedBaseModel] = useState(SUPPORTED_BASE_MODELS[0]); const [modelSlug, setModelSlug] = useState(humanId({ separator: "-", capitalize: false })); @@ -71,19 +73,17 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { const createFineTuneMutation = api.fineTunes.create.useMutation(); const [createFineTune, creationInProgress] = useHandledAsyncCallback(async () => { - if (!selectedProjectId || !modelSlug || !selectedBaseModel || !selectedLogIds.size) return; + if (!selectedProjectId || !modelSlug || !selectedBaseModel || !dataset) return; await createFineTuneMutation.mutateAsync({ - projectId: selectedProjectId, slug: modelSlug, baseModel: selectedBaseModel, - selectedLogIds: Array.from(selectedLogIds), + datasetId: dataset.id, }); await utils.fineTunes.list.invalidate(); await router.push({ pathname: "/fine-tunes" }); - clearSelectedLogIds(); disclosure.onClose(); - }, [createFineTuneMutation, selectedProjectId, selectedLogIds, modelSlug, selectedBaseModel]); + }, [createFineTuneMutation, selectedProjectId, modelSlug, selectedBaseModel]); return ( @@ -99,7 +99,8 @@ const FineTuneModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { - We'll train on the {selectedLogIds.size} logs you've selected. + We'll train on {datasetEntries?.trainingCount} and test on{" "} + {datasetEntries?.testingCount} entries in this dataset. diff --git a/app/src/components/experiments/DeleteExperimentDialog.tsx b/app/src/components/experiments/DeleteExperimentDialog.tsx index a82ec84..8e0ac11 100644 --- a/app/src/components/experiments/DeleteExperimentDialog.tsx +++ b/app/src/components/experiments/DeleteExperimentDialog.tsx @@ -27,7 +27,7 @@ const DeleteExperimentDialog = ({ const mutation = api.experiments.delete.useMutation(); const utils = api.useContext(); - const [onDeleteConfirm] = useHandledAsyncCallback(async () => { + const [onDeleteConfirm, deletionInProgress] = useHandledAsyncCallback(async () => { if (!experimentId) return; await mutation.mutateAsync({ id: experimentId }); await utils.experiments.list.invalidate(); @@ -53,7 +53,12 @@ const DeleteExperimentDialog = ({ - diff --git a/app/src/components/experiments/ExperimentHeaderButtons/DeleteDialog.tsx b/app/src/components/experiments/ExperimentHeaderButtons/DeleteDialog.tsx deleted file mode 100644 index 5fe2b60..0000000 --- a/app/src/components/experiments/ExperimentHeaderButtons/DeleteDialog.tsx +++ /dev/null @@ -1,57 +0,0 @@ -import { - Button, - AlertDialog, - AlertDialogBody, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogContent, - AlertDialogOverlay, -} from "@chakra-ui/react"; - -import { useRouter } from "next/router"; -import { useRef } from "react"; -import { api } from "~/utils/api"; -import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; - -export const DeleteDialog = ({ onClose }: { onClose: () => void }) => { - const experiment = useExperiment(); - const deleteMutation = api.experiments.delete.useMutation(); - const utils = api.useContext(); - const router = useRouter(); - - const cancelRef = useRef(null); - - const [onDeleteConfirm] = useHandledAsyncCallback(async () => { - if (!experiment.data?.id) return; - await deleteMutation.mutateAsync({ id: experiment.data.id }); - await utils.experiments.list.invalidate(); - await router.push({ pathname: "/experiments" }); - onClose(); - }, [deleteMutation, experiment.data?.id, router]); - - return ( - - - - - Delete Experiment - - - - If you delete this experiment all the associated prompts and scenarios will be deleted - as well. Are you sure? - - - - - - - - - - ); -}; diff --git a/app/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx b/app/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx index 97960b6..22cc9bc 100644 --- a/app/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx +++ b/app/src/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons.tsx @@ -3,17 +3,14 @@ import { useOnForkButtonPressed } from "./useOnForkButtonPressed"; import { useExperiment } from "~/utils/hooks"; import { BsGearFill } from "react-icons/bs"; import { TbGitFork } from "react-icons/tb"; -import { useAppStore } from "~/state/store"; -export const ExperimentHeaderButtons = () => { +export const ExperimentHeaderButtons = ({ openDrawer }: { openDrawer: () => void }) => { const experiment = useExperiment(); const canModify = experiment.data?.access.canModify ?? false; const { onForkButtonPressed, isForking } = useOnForkButtonPressed(); - const openDrawer = useAppStore((s) => s.openDrawer); - if (experiment.isLoading) return null; return ( diff --git a/app/src/components/ExperimentSettingsDrawer/DeleteButton.tsx b/app/src/components/experiments/ExperimentSettingsDrawer/DeleteButton.tsx similarity index 81% rename from app/src/components/ExperimentSettingsDrawer/DeleteButton.tsx rename to app/src/components/experiments/ExperimentSettingsDrawer/DeleteButton.tsx index c4be40d..0ff075b 100644 --- a/app/src/components/ExperimentSettingsDrawer/DeleteButton.tsx +++ b/app/src/components/experiments/ExperimentSettingsDrawer/DeleteButton.tsx @@ -2,17 +2,15 @@ import { Button, Icon, useDisclosure, Text } from "@chakra-ui/react"; import { useRouter } from "next/router"; import { BsTrash } from "react-icons/bs"; -import { useAppStore } from "~/state/store"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -import DeleteExperimentDialog from "../experiments/DeleteExperimentDialog"; +import DeleteExperimentDialog from "../DeleteExperimentDialog"; -export const DeleteButton = () => { +export const DeleteButton = ({ closeDrawer }: { closeDrawer: () => void }) => { const experiment = useExperiment(); const router = useRouter(); const disclosure = useDisclosure(); - const closeDrawer = useAppStore((s) => s.closeDrawer); const [onDelete] = useHandledAsyncCallback(async () => { await router.push({ pathname: "/experiments" }); closeDrawer(); diff --git a/app/src/components/OutputsTable/EditEvaluations.tsx b/app/src/components/experiments/ExperimentSettingsDrawer/EditEvaluations.tsx similarity index 99% rename from app/src/components/OutputsTable/EditEvaluations.tsx rename to app/src/components/experiments/ExperimentSettingsDrawer/EditEvaluations.tsx index 1375cf8..1b7a09f 100644 --- a/app/src/components/OutputsTable/EditEvaluations.tsx +++ b/app/src/components/experiments/ExperimentSettingsDrawer/EditEvaluations.tsx @@ -19,7 +19,7 @@ import { useCallback, useState } from "react"; import { BsPencil, BsX } from "react-icons/bs"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; -import AutoResizeTextArea from "../AutoResizeTextArea"; +import AutoResizeTextArea from "~/components/AutoResizeTextArea"; type EvalValues = Pick; diff --git a/app/src/components/OutputsTable/EditScenarioVars.tsx b/app/src/components/experiments/ExperimentSettingsDrawer/EditScenarioVars.tsx similarity index 98% rename from app/src/components/OutputsTable/EditScenarioVars.tsx rename to app/src/components/experiments/ExperimentSettingsDrawer/EditScenarioVars.tsx index 8cee904..3f65c0c 100644 --- a/app/src/components/OutputsTable/EditScenarioVars.tsx +++ b/app/src/components/experiments/ExperimentSettingsDrawer/EditScenarioVars.tsx @@ -5,7 +5,7 @@ import { BsPencil, BsX } from "react-icons/bs"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback, useScenarioVars } from "~/utils/hooks"; import { maybeReportError } from "~/utils/errorHandling/maybeReportError"; -import { FloatingLabelInput } from "./FloatingLabelInput"; +import { FloatingLabelInput } from "~/components/OutputsTable/FloatingLabelInput"; export const ScenarioVar = ({ variable, diff --git a/app/src/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx b/app/src/components/experiments/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx similarity index 60% rename from app/src/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx rename to app/src/components/experiments/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx index 8536bcf..db2fac4 100644 --- a/app/src/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx +++ b/app/src/components/experiments/ExperimentSettingsDrawer/ExperimentSettingsDrawer.tsx @@ -7,18 +7,19 @@ import { DrawerOverlay, Heading, VStack, + type UseDisclosureReturn, } from "@chakra-ui/react"; -import EditScenarioVars from "../OutputsTable/EditScenarioVars"; -import EditEvaluations from "../OutputsTable/EditEvaluations"; -import { useAppStore } from "~/state/store"; +import EditScenarioVars from "./EditScenarioVars"; +import EditEvaluations from "./EditEvaluations"; import { DeleteButton } from "./DeleteButton"; -export default function ExperimentSettingsDrawer() { - const isOpen = useAppStore((state) => state.drawerOpen); - const closeDrawer = useAppStore((state) => state.closeDrawer); - +export default function ExperimentSettingsDrawer({ + disclosure, +}: { + disclosure: UseDisclosureReturn; +}) { return ( - + @@ -31,7 +32,7 @@ export default function ExperimentSettingsDrawer() { - + diff --git a/app/src/components/nav/AppShell.tsx b/app/src/components/nav/AppShell.tsx index 85be5f0..f6a6057 100644 --- a/app/src/components/nav/AppShell.tsx +++ b/app/src/components/nav/AppShell.tsx @@ -17,7 +17,7 @@ import { useRouter } from "next/router"; import { BsGearFill, BsGithub, BsPersonCircle } from "react-icons/bs"; import { IoStatsChartOutline } from "react-icons/io5"; import { RiHome3Line, RiFlaskLine } from "react-icons/ri"; -import { AiOutlineThunderbolt } from "react-icons/ai"; +import { AiOutlineThunderbolt, AiOutlineDatabase } from "react-icons/ai"; import { FaReadme } from "react-icons/fa"; import { signIn, useSession } from "next-auth/react"; @@ -78,6 +78,7 @@ const NavSidebar = () => { + diff --git a/app/src/components/requestLogs/AddToDatasetButton.tsx b/app/src/components/requestLogs/AddToDatasetButton.tsx new file mode 100644 index 0000000..aa62098 --- /dev/null +++ b/app/src/components/requestLogs/AddToDatasetButton.tsx @@ -0,0 +1,194 @@ +import { useState, useEffect, useMemo } from "react"; +import { + Modal, + ModalOverlay, + ModalContent, + ModalHeader, + ModalCloseButton, + ModalBody, + ModalFooter, + HStack, + VStack, + Icon, + Text, + Button, + Flex, + Input, + useDisclosure, + type UseDisclosureReturn, + Checkbox, +} from "@chakra-ui/react"; +import { FiPlusSquare } from "react-icons/fi"; + +import { useDatasets, useHandledAsyncCallback } from "~/utils/hooks"; +import { api } from "~/utils/api"; +import { useAppStore } from "~/state/store"; +import ActionButton from "../ActionButton"; +import InputDropdown from "../InputDropdown"; +import { maybeReportError } from "~/utils/errorHandling/maybeReportError"; +import { useRouter } from "next/router"; + +const AddToDatasetButton = () => { + const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); + + const disclosure = useDisclosure(); + + return ( + <> + + + + ); +}; + +export default AddToDatasetButton; + +const AddToDatasetModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => { + const selectedProjectId = useAppStore((s) => s.selectedProjectId); + const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); + const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds); + const router = useRouter(); + + const datasets = useDatasets().data; + + const existingDatasetOptions = useMemo( + () => + datasets?.length + ? datasets.map((d) => ({ label: d.name, id: d.id })) + : [{ label: "", id: "" }], + [datasets], + ); + + const [selectedDatasetOption, setSelectedDatasetOption] = useState(existingDatasetOptions?.[0]); + const [newDatasetName, setNewDatasetName] = useState(""); + const [createNewDataset, setCreateNewDataset] = useState(false); + + useEffect(() => { + if (disclosure.isOpen) { + setSelectedDatasetOption(existingDatasetOptions?.[0]); + setCreateNewDataset(!existingDatasetOptions[0]?.id); + } + }, [disclosure.isOpen, existingDatasetOptions]); + + const createDatasetEntriesMutation = api.datasetEntries.create.useMutation(); + + const [addToDataset, addingInProgress] = useHandledAsyncCallback(async () => { + if ( + !selectedProjectId || + !selectedLogIds.size || + !(createNewDataset ? newDatasetName : selectedDatasetOption?.id) + ) + return; + const datasetParams = createNewDataset + ? { newDatasetParams: { projectId: selectedProjectId, name: newDatasetName } } + : { datasetId: selectedDatasetOption?.id }; + const response = await createDatasetEntriesMutation.mutateAsync({ + loggedCallIds: Array.from(selectedLogIds), + ...datasetParams, + }); + + if (maybeReportError(response)) return; + + const datasetId = response.payload; + + await router.push({ pathname: "/datasets/[id]", query: { id: datasetId } }); + + disclosure.onClose(); + clearSelectedLogIds(); + }, [ + selectedProjectId, + selectedLogIds, + createNewDataset, + selectedDatasetOption?.id, + newDatasetName, + router, + ]); + + return ( + + + + + + + Add to Dataset + + + + + + + We'll add the {selectedLogIds.size} logs you have selected to the dataset you + choose. + + + {existingDatasetOptions?.length && selectedDatasetOption && ( + + + Dataset: + + option.label} + onSelect={(option) => setSelectedDatasetOption(option)} + inputGroupProps={{ w: 48 }} + isDisabled={createNewDataset} + /> + setCreateNewDataset(e.target.checked)} + paddingLeft={4} + isDisabled={!existingDatasetOptions[0]?.id} + > + Create New Dataset + + + )} + + {createNewDataset && ( + + + Dataset Name: + + setNewDatasetName(e.target.value)} + /> + + )} + + + + + + + + + + + + ); +}; diff --git a/app/src/components/requestLogs/ColumnVisiblityDropdown.tsx b/app/src/components/requestLogs/ColumnVisiblityDropdown.tsx index 1155ec8..a501bfc 100644 --- a/app/src/components/requestLogs/ColumnVisiblityDropdown.tsx +++ b/app/src/components/requestLogs/ColumnVisiblityDropdown.tsx @@ -17,7 +17,7 @@ import { useMemo } from "react"; import { useIsClientRehydrated, useTagNames } from "~/utils/hooks"; import { useAppStore } from "~/state/store"; import { StaticColumnKeys } from "~/state/columnVisiblitySlice"; -import ActionButton from "./ActionButton"; +import ActionButton from "../ActionButton"; const ColumnVisiblityDropdown = () => { const tagNames = useTagNames().data; diff --git a/app/src/components/requestLogs/ExportButton.tsx b/app/src/components/requestLogs/ExportButton.tsx index af61af6..ba5ec7f 100644 --- a/app/src/components/requestLogs/ExportButton.tsx +++ b/app/src/components/requestLogs/ExportButton.tsx @@ -28,7 +28,7 @@ import { BiExport } from "react-icons/bi"; import { useHandledAsyncCallback } from "~/utils/hooks"; import { api } from "~/utils/api"; import { useAppStore } from "~/state/store"; -import ActionButton from "./ActionButton"; +import ActionButton from "../ActionButton"; import InputDropdown from "../InputDropdown"; import { FiChevronUp, FiChevronDown } from "react-icons/fi"; import InfoCircle from "../InfoCircle"; @@ -81,7 +81,7 @@ const ExportLogsModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => return; const response = await exportLogsMutation.mutateAsync({ projectId: selectedProjectId, - selectedLogIds: Array.from(selectedLogIds), + loggedCallIds: Array.from(selectedLogIds), testingSplit, selectedExportFormat, removeDuplicates, diff --git a/app/src/components/requestLogs/TableRow.tsx b/app/src/components/requestLogs/TableRow.tsx index 0ca1372..a4335fe 100644 --- a/app/src/components/requestLogs/TableRow.tsx +++ b/app/src/components/requestLogs/TableRow.tsx @@ -9,17 +9,14 @@ import { Collapse, HStack, VStack, - Button, - ButtonGroup, Text, Checkbox, Link as ChakraLink, } from "@chakra-ui/react"; -import Link from "next/link"; import dayjs from "~/utils/dayjs"; import { type RouterOutputs } from "~/utils/api"; -import { FormattedJson } from "./FormattedJson"; +import { FormattedJson } from "../FormattedJson"; import { useAppStore } from "~/state/store"; import { useIsClientRehydrated, useLoggedCalls, useTagNames } from "~/utils/hooks"; import { useMemo } from "react"; @@ -176,23 +173,16 @@ export const TableRow = ({ - - - - Input - - - - Output - - - - - - - + + + Input + + + + Output + + + diff --git a/app/src/modelProviders/openai-ChatCompletion/index.ts b/app/src/modelProviders/openai-ChatCompletion/index.ts index b5dc00f..6e863e9 100644 --- a/app/src/modelProviders/openai-ChatCompletion/index.ts +++ b/app/src/modelProviders/openai-ChatCompletion/index.ts @@ -42,24 +42,21 @@ const modelProvider: OpenaiChatModelProvider = { canStream: true, getCompletion, getUsage: (input, output) => { - if (output.choices.length === 0) return null; - const model = modelProvider.getModel(input); if (!model) return null; let inputTokens: number; let outputTokens: number; - if (output.usage) { + if (output?.usage) { inputTokens = output.usage.prompt_tokens; outputTokens = output.usage.completion_tokens; } else { try { inputTokens = countOpenAIChatTokens(model, input.messages); - outputTokens = countOpenAIChatTokens( - model, - output.choices.map((c) => c.message).filter(truthyFilter), - ); + outputTokens = output + ? countOpenAIChatTokens(model, output.choices.map((c) => c.message).filter(truthyFilter)) + : 0; } catch (err) { inputTokens = 0; outputTokens = 0; diff --git a/app/src/modelProviders/types.ts b/app/src/modelProviders/types.ts index 89ea9be..dbbfc67 100644 --- a/app/src/modelProviders/types.ts +++ b/app/src/modelProviders/types.ts @@ -59,7 +59,7 @@ export type ModelProvider Promise>; getUsage: ( input: InputSchema, - output: OutputSchema, + output?: OutputSchema, ) => { gpuRuntime?: number; inputTokens?: number; outputTokens?: number; cost?: number } | null; // This is just a convenience for type inference, don't use it at runtime diff --git a/app/src/pages/datasets/[id].tsx b/app/src/pages/datasets/[id].tsx new file mode 100644 index 0000000..86780a6 --- /dev/null +++ b/app/src/pages/datasets/[id].tsx @@ -0,0 +1,113 @@ +import { + Breadcrumb, + BreadcrumbItem, + Center, + Flex, + Icon, + Input, + VStack, + HStack, + useDisclosure, +} from "@chakra-ui/react"; +import Link from "next/link"; + +import { useState, useEffect } from "react"; +import { AiOutlineDatabase } from "react-icons/ai"; +import AppShell from "~/components/nav/AppShell"; +import { api } from "~/utils/api"; +import { useDataset, useHandledAsyncCallback } from "~/utils/hooks"; +import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; +import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; +import DatasetConfigurationDrawer from "~/components/datasets/DatasetConfigurationDrawer/DatasetConfigurationDrawer"; +import { DatasetHeaderButtons } from "~/components/datasets/DatasetHeaderButtons"; +import DatasetEntriesTable from "~/components/datasets/DatasetEntriesTable/DatasetEntriesTable"; +import DatasetEntryPaginator from "~/components/datasets/DatasetEntryPaginator"; +import { useAppStore } from "~/state/store"; +import FineTuneButton from "~/components/datasets/FineTuneButton"; +import ExperimentButton from "~/components/datasets/ExperimentButton"; + +export default function Dataset() { + const utils = api.useContext(); + + const dataset = useDataset(); + + const drawerDisclosure = useDisclosure(); + const [name, setName] = useState(dataset.data?.name || ""); + useEffect(() => { + setName(dataset.data?.name || ""); + }, [dataset.data?.name]); + + useEffect(() => { + useAppStore.getState().sharedArgumentsEditor.loadMonaco().catch(console.error); + }, []); + + const updateMutation = api.datasets.update.useMutation(); + const [onSaveName] = useHandledAsyncCallback(async () => { + if (name && name !== dataset.data?.name && dataset.data?.id) { + await updateMutation.mutateAsync({ + id: dataset.data.id, + name, + }); + await Promise.all([utils.datasets.list.invalidate(), utils.datasets.get.invalidate()]); + } + }, [updateMutation, dataset.data?.id, dataset.data?.name, name]); + + if (!dataset.isLoading && !dataset.data) { + return ( + +
+
Dataset not found 😕
+
+
+ ); + } + + return ( + <> + + + + + + + + + + + Datasets + + + + + setName(e.target.value)} + onBlur={onSaveName} + borderWidth={1} + borderColor="transparent" + fontSize={16} + px={0} + minW={{ base: 100, lg: 300 }} + flex={1} + _hover={{ borderColor: "gray.300" }} + _focus={{ borderColor: "blue.500", outline: "none" }} + /> + + + + + + + + + + + + + + + + + ); +} diff --git a/app/src/pages/datasets/index.tsx b/app/src/pages/datasets/index.tsx new file mode 100644 index 0000000..a223630 --- /dev/null +++ b/app/src/pages/datasets/index.tsx @@ -0,0 +1,17 @@ +import { VStack, Text, Divider } from "@chakra-ui/react"; +import AppShell from "~/components/nav/AppShell"; +import DatasetsTable from "~/components/datasets/DatasetsTable"; + +export default function DatasetsPage() { + return ( + + + + Datasets + + + + + + ); +} diff --git a/app/src/pages/experiments/[experimentSlug].tsx b/app/src/pages/experiments/[experimentSlug].tsx index db72b49..77e15b6 100644 --- a/app/src/pages/experiments/[experimentSlug].tsx +++ b/app/src/pages/experiments/[experimentSlug].tsx @@ -8,26 +8,25 @@ import { Input, Text, VStack, + useDisclosure, } from "@chakra-ui/react"; import Link from "next/link"; -import { useRouter } from "next/router"; import { useState, useEffect } from "react"; import { RiFlaskLine } from "react-icons/ri"; import OutputsTable from "~/components/OutputsTable"; -import ExperimentSettingsDrawer from "~/components/ExperimentSettingsDrawer/ExperimentSettingsDrawer"; +import ExperimentSettingsDrawer from "~/components/experiments/ExperimentSettingsDrawer/ExperimentSettingsDrawer"; +import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons"; import AppShell from "~/components/nav/AppShell"; import { api } from "~/utils/api"; import { useExperiment, useHandledAsyncCallback } from "~/utils/hooks"; import { useAppStore } from "~/state/store"; import { useSyncVariantEditor } from "~/state/sync"; -import { ExperimentHeaderButtons } from "~/components/experiments/ExperimentHeaderButtons/ExperimentHeaderButtons"; import Head from "next/head"; import PageHeaderContainer from "~/components/nav/PageHeaderContainer"; import ProjectBreadcrumbContents from "~/components/nav/ProjectBreadcrumbContents"; export default function Experiment() { - const router = useRouter(); const utils = api.useContext(); useSyncVariantEditor(); @@ -44,6 +43,7 @@ export default function Experiment() { useAppStore.getState().sharedVariantEditor.loadMonaco().catch(console.error); }, []); + const drawerDisclosure = useDisclosure(); const [label, setLabel] = useState(experiment.data?.label || ""); useEffect(() => { setLabel(experiment.data?.label || ""); @@ -121,11 +121,11 @@ export default function Experiment() { )} - + - + - +
diff --git a/app/src/pages/request-logs/index.tsx b/app/src/pages/request-logs/index.tsx index f9016dc..14c5d85 100644 --- a/app/src/pages/request-logs/index.tsx +++ b/app/src/pages/request-logs/index.tsx @@ -4,14 +4,13 @@ import { Text, VStack, Divider, HStack, Box } from "@chakra-ui/react"; import AppShell from "~/components/nav/AppShell"; import LoggedCallTable from "~/components/requestLogs/LoggedCallsTable"; import LoggedCallsPaginator from "~/components/requestLogs/LoggedCallsPaginator"; -import ActionButton from "~/components/requestLogs/ActionButton"; +import ActionButton from "~/components/ActionButton"; import { useAppStore } from "~/state/store"; -import { RiFlaskLine } from "react-icons/ri"; import { FiFilter } from "react-icons/fi"; import LogFilters from "~/components/requestLogs/LogFilters/LogFilters"; import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown"; -import FineTuneButton from "~/components/requestLogs/FineTuneButton"; import ExportButton from "~/components/requestLogs/ExportButton"; +import AddToDatasetButton from "~/components/requestLogs/AddToDatasetButton"; export default function LoggedCalls() { const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds); @@ -27,16 +26,7 @@ export default function LoggedCalls() { - - { - console.log("experimenting with these ids", selectedLogIds); - }} - label="Experiment" - icon={RiFlaskLine} - isDisabled={selectedLogIds.size === 0} - requireBeta - /> + { + const { datasetId, page, pageSize } = input; + + const { projectId } = await prisma.dataset.findUniqueOrThrow({ + where: { id: datasetId }, + }); + await requireCanViewProject(projectId, ctx); + + const [entries, matchingEntries, trainingCount, testingCount] = await prisma.$transaction([ + prisma.datasetEntry.findMany({ + where: { + datasetId: datasetId, + }, + orderBy: [{ createdAt: "desc" }, { id: "desc" }], + skip: (page - 1) * pageSize, + take: pageSize, + }), + prisma.datasetEntry.findMany({ + where: { + datasetId: datasetId, + }, + select: { + id: true, + }, + }), + prisma.datasetEntry.count({ + where: { + datasetId: datasetId, + type: "TRAIN", + }, + }), + prisma.datasetEntry.count({ + where: { + datasetId: datasetId, + type: "TEST", + }, + }), + ]); + + return { + entries, + matchingEntryIds: matchingEntries.map((entry) => entry.id), + trainingCount, + testingCount, + }; + }), + get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { + const entry = await prisma.datasetEntry.findUniqueOrThrow({ + where: { id: input.id }, + include: { + dataset: true, + }, + }); + + if (!entry.dataset) { + throw new TRPCError({ message: "Dataset not found for dataset entry", code: "NOT_FOUND" }); + } + + await requireCanViewProject(entry.dataset.projectId, ctx); + + if (!entry) { + throw new TRPCError({ message: "Dataset entry not found", code: "NOT_FOUND" }); + } + + return entry; + }), + create: protectedProcedure + .input( + z.object({ + datasetId: z.string().optional(), + newDatasetParams: z + .object({ + projectId: z.string(), + name: z.string(), + }) + .optional(), + loggedCallIds: z.string().array(), + }), + ) + .mutation(async ({ input, ctx }) => { + let datasetId: string; + let trainingRatio = 0.8; + if (input.datasetId) { + datasetId = input.datasetId; + const { projectId, trainingRatio: datasetTrainingRatio } = + await prisma.dataset.findUniqueOrThrow({ + where: { id: input.datasetId }, + }); + trainingRatio = datasetTrainingRatio; + await requireCanModifyProject(projectId, ctx); + } else if (input.newDatasetParams) { + await requireCanModifyProject(input.newDatasetParams.projectId, ctx); + datasetId = uuidv4(); + } else { + return error("No datasetId or newDatasetParams provided"); + } + + const [loggedCalls, existingTrainingCount, existingTestingCount] = await prisma.$transaction([ + prisma.loggedCall.findMany({ + where: { + id: { + in: input.loggedCallIds, + }, + modelResponse: { + isNot: null, + }, + }, + include: { + modelResponse: { + select: { + reqPayload: true, + respPayload: true, + inputTokens: true, + outputTokens: true, + }, + }, + }, + }), + prisma.datasetEntry.count({ + where: { + datasetId, + type: "TRAIN", + }, + }), + prisma.datasetEntry.count({ + where: { + datasetId, + type: "TEST", + }, + }), + ]); + + const shuffledLoggedCalls = shuffle(loggedCalls); + + const totalEntries = existingTrainingCount + existingTestingCount + loggedCalls.length; + const numTrainingToAdd = Math.floor(trainingRatio * totalEntries) - existingTrainingCount; + + const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyInput[] = []; + + let i = 0; + for (const loggedCall of shuffledLoggedCalls) { + const inputMessages = ( + loggedCall.modelResponse?.reqPayload as unknown as CompletionCreateParams + ).messages; + let output: ChatCompletion.Choice.Message | undefined = undefined; + const resp = loggedCall.modelResponse?.respPayload as unknown as ChatCompletion | undefined; + if (resp && resp.choices?.[0]) { + output = resp.choices[0].message; + } else { + output = { + role: "assistant", + content: "", + }; + } + + datasetEntriesToCreate.push({ + datasetId, + loggedCallId: loggedCall.id, + input: inputMessages as unknown as Prisma.InputJsonValue, + output: output as unknown as Prisma.InputJsonValue, + inputTokens: loggedCall.modelResponse?.inputTokens || 0, + outputTokens: loggedCall.modelResponse?.outputTokens || 0, + type: i < numTrainingToAdd ? "TRAIN" : "TEST", + }); + i++; + } + + // Ensure dataset and dataset entries are created atomically + await prisma.$transaction([ + prisma.dataset.upsert({ + where: { id: datasetId }, + update: {}, + create: { + id: datasetId, + projectId: input.newDatasetParams?.projectId ?? "", + name: input.newDatasetParams?.name ?? "", + trainingRatio, + }, + }), + prisma.datasetEntry.createMany({ + data: shuffle(datasetEntriesToCreate), + }), + ]); + + return success(datasetId); + }), + + update: protectedProcedure + .input( + z.object({ + id: z.string(), + updates: z.object({ + type: z.enum(["TRAIN", "TEST"]).optional(), + input: z.string().optional(), + output: z.string().optional(), + }), + }), + ) + .mutation(async ({ input, ctx }) => { + const { dataset } = await prisma.datasetEntry.findUniqueOrThrow({ + where: { id: input.id }, + include: { + dataset: true, + }, + }); + + if (!dataset) { + return error("Dataset not found for dataset entry"); + } + + await requireCanModifyProject(dataset.projectId, ctx); + + let parsedInput = undefined; + let inputTokens = undefined; + if (input.updates.input) { + parsedInput = JSON.parse(input.updates.input); + inputTokens = countOpenAIChatTokens( + "gpt-4-0613", + parsedInput as unknown as CreateChatCompletionRequestMessage[], + ); + } + + let parsedOutput = undefined; + let outputTokens = undefined; + if (input.updates.output) { + parsedOutput = JSON.parse(input.updates.output); + outputTokens = countOpenAIChatTokens("gpt-4-0613", [ + parsedOutput as unknown as ChatCompletion.Choice.Message, + ]); + } + + await prisma.datasetEntry.update({ + where: { id: input.id }, + data: { + type: input.updates.type, + input: parsedInput, + output: parsedOutput, + inputTokens, + outputTokens, + }, + }); + + return success("Dataset entry updated"); + }), + + delete: protectedProcedure + .input(z.object({ ids: z.string().array() })) + .mutation(async ({ input, ctx }) => { + if (input.ids.length === 0) { + return error("No ids provided"); + } + const { dataset } = await prisma.datasetEntry.findUniqueOrThrow({ + where: { id: input.ids[0] }, + include: { + dataset: true, + }, + }); + + if (!dataset) { + return error("Dataset not found for dataset entry"); + } + + await requireCanModifyProject(dataset.projectId, ctx); + + await prisma.datasetEntry.deleteMany({ + where: { + id: { + in: input.ids, + }, + datasetId: dataset?.id, + }, + }); + + return success("Dataset entries deleted"); + }), +}); diff --git a/app/src/server/api/routers/datasets.router.ts b/app/src/server/api/routers/datasets.router.ts new file mode 100644 index 0000000..596717e --- /dev/null +++ b/app/src/server/api/routers/datasets.router.ts @@ -0,0 +1,97 @@ +import { z } from "zod"; +import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; +import { prisma } from "~/server/db"; +import { requireCanModifyProject, requireCanViewProject } from "~/utils/accessControl"; +import { success } from "~/utils/errorHandling/standardResponses"; + +export const datasetsRouter = createTRPCRouter({ + get: protectedProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => { + const dataset = await prisma.dataset.findUniqueOrThrow({ + where: { id: input.id }, + include: { + project: true, + }, + }); + + await requireCanViewProject(dataset.projectId, ctx); + + return dataset; + }), + list: protectedProcedure + .input(z.object({ projectId: z.string() })) + .query(async ({ input, ctx }) => { + await requireCanViewProject(input.projectId, ctx); + + return await prisma.dataset.findMany({ + where: { + projectId: input.projectId, + }, + include: { + _count: { + select: { + datasetEntries: true, + }, + }, + }, + orderBy: { createdAt: "desc" }, + }); + }), + + create: protectedProcedure + .input( + z.object({ + projectId: z.string(), + name: z.string(), + }), + ) + .mutation(async ({ input, ctx }) => { + await requireCanModifyProject(input.projectId, ctx); + + const dataset = await prisma.dataset.create({ + data: { + projectId: input.projectId, + name: input.name, + }, + }); + + return success(dataset); + }), + + update: protectedProcedure + .input( + z.object({ + id: z.string(), + name: z.string(), + }), + ) + .mutation(async ({ input, ctx }) => { + const { projectId } = await prisma.dataset.findUniqueOrThrow({ + where: { id: input.id }, + }); + await requireCanModifyProject(projectId, ctx); + + await prisma.dataset.update({ + where: { id: input.id }, + data: { + name: input.name, + }, + }); + + return success("Dataset updated"); + }), + + delete: protectedProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + const { projectId } = await prisma.dataset.findUniqueOrThrow({ + where: { id: input.id }, + }); + await requireCanModifyProject(projectId, ctx); + + await prisma.dataset.delete({ + where: { id: input.id }, + }); + + return success("Dataset deleted"); + }), +}); diff --git a/app/src/server/api/routers/fineTunes.router.ts b/app/src/server/api/routers/fineTunes.router.ts index cff976e..9e06b66 100644 --- a/app/src/server/api/routers/fineTunes.router.ts +++ b/app/src/server/api/routers/fineTunes.router.ts @@ -1,6 +1,4 @@ import { z } from "zod"; -import { v4 as uuidv4 } from "uuid"; -import { type Prisma } from "@prisma/client"; import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; @@ -55,14 +53,18 @@ export const fineTunesRouter = createTRPCRouter({ create: protectedProcedure .input( z.object({ - projectId: z.string(), - selectedLogIds: z.array(z.string()), + datasetId: z.string(), slug: z.string(), baseModel: z.string(), }), ) .mutation(async ({ input, ctx }) => { - await requireCanModifyProject(input.projectId, ctx); + const { projectId } = await prisma.dataset.findUniqueOrThrow({ + where: { + id: input.datasetId, + }, + }); + await requireCanModifyProject(projectId, ctx); const existingFineTune = await prisma.fineTune.findFirst({ where: { @@ -74,39 +76,14 @@ export const fineTunesRouter = createTRPCRouter({ return error("A fine tune with that slug already exists"); } - const newDatasetId = uuidv4(); - - const datasetEntriesToCreate: Prisma.DatasetEntryCreateManyDatasetInput[] = - input.selectedLogIds.map((loggedCallId) => ({ - loggedCallId, - })); - - await prisma.$transaction([ - prisma.dataset.create({ - data: { - id: newDatasetId, - name: input.slug, - project: { - connect: { - id: input.projectId, - }, - }, - datasetEntries: { - createMany: { - data: datasetEntriesToCreate, - }, - }, - }, - }), - prisma.fineTune.create({ - data: { - projectId: input.projectId, - slug: input.slug, - baseModel: input.baseModel, - datasetId: newDatasetId, - }, - }), - ]); + await prisma.fineTune.create({ + data: { + projectId, + slug: input.slug, + baseModel: input.baseModel, + datasetId: input.datasetId, + }, + }); return success(); }), diff --git a/app/src/server/api/routers/loggedCalls.router.ts b/app/src/server/api/routers/loggedCalls.router.ts index 6eb5dc1..ede408b 100644 --- a/app/src/server/api/routers/loggedCalls.router.ts +++ b/app/src/server/api/routers/loggedCalls.router.ts @@ -189,7 +189,7 @@ export const loggedCallsRouter = createTRPCRouter({ .input( z.object({ projectId: z.string(), - selectedLogIds: z.string().array(), + loggedCallIds: z.string().array(), testingSplit: z.number(), selectedExportFormat: z.string(), removeDuplicates: z.boolean(), @@ -203,7 +203,7 @@ export const loggedCallsRouter = createTRPCRouter({ where: { originalLoggedCall: { projectId: input.projectId, - id: { in: input.selectedLogIds }, + id: { in: input.loggedCallIds }, }, statusCode: 200, }, diff --git a/app/src/server/api/routers/promptVariants.router.ts b/app/src/server/api/routers/promptVariants.router.ts index ececa15..a37f448 100644 --- a/app/src/server/api/routers/promptVariants.router.ts +++ b/app/src/server/api/routers/promptVariants.router.ts @@ -93,17 +93,12 @@ export const promptVariantsRouter = createTRPCRouter({ visible: true, }, }); - const outputCount = await prisma.scenarioVariantCell.count({ + const finishedCount = await prisma.scenarioVariantCell.count({ where: { promptVariantId: input.variantId, testScenario: { visible: true }, - modelResponses: { - some: { - outdated: false, - respPayload: { - not: Prisma.AnyNull, - }, - }, + retrievalStatus: { + in: ["COMPLETE", "ERROR"], }, }, }); @@ -131,7 +126,7 @@ export const promptVariantsRouter = createTRPCRouter({ const inputTokens = overallTokens._sum?.inputTokens ?? 0; const outputTokens = overallTokens._sum?.outputTokens ?? 0; - const awaitingCompletions = outputCount < scenarioCount; + const awaitingCompletions = finishedCount < scenarioCount; const awaitingEvals = !!evalResults.find( (result) => result.totalCount < scenarioCount * evals.length, @@ -143,7 +138,7 @@ export const promptVariantsRouter = createTRPCRouter({ outputTokens, overallCost: overallTokens._sum?.cost ?? 0, scenarioCount, - outputCount, + finishedCount, awaitingCompletions, awaitingEvals, }; diff --git a/app/src/state/selectedDatasetEntriesSlice.ts b/app/src/state/selectedDatasetEntriesSlice.ts new file mode 100644 index 0000000..ef5b01d --- /dev/null +++ b/app/src/state/selectedDatasetEntriesSlice.ts @@ -0,0 +1,33 @@ +import { type SliceCreator } from "./store"; + +export type SelectedDatasetEntriesSlice = { + selectedIds: Set; + toggleSelectedId: (id: string) => void; + addSelectedIds: (ids: string[]) => void; + clearSelectedIds: () => void; +}; + +export const createSelectedDatasetEntriesSlice: SliceCreator = ( + set, +) => ({ + selectedIds: new Set(), + toggleSelectedId: (id: string) => + set((state) => { + if (state.selectedDatasetEntries.selectedIds.has(id)) { + state.selectedDatasetEntries.selectedIds.delete(id); + } else { + state.selectedDatasetEntries.selectedIds.add(id); + } + }), + addSelectedIds: (ids: string[]) => + set((state) => { + state.selectedDatasetEntries.selectedIds = new Set([ + ...state.selectedDatasetEntries.selectedIds, + ...ids, + ]); + }), + clearSelectedIds: () => + set((state) => { + state.selectedDatasetEntries.selectedIds = new Set(); + }), +}); diff --git a/app/src/state/selectedLogsSlice.ts b/app/src/state/selectedLogsSlice.ts index 1af55c1..e0c4691 100644 --- a/app/src/state/selectedLogsSlice.ts +++ b/app/src/state/selectedLogsSlice.ts @@ -7,7 +7,7 @@ export type SelectedLogsSlice = { clearSelectedLogIds: () => void; }; -export const createSelectedLogsSlice: SliceCreator = (set, get) => ({ +export const createSelectedLogsSlice: SliceCreator = (set) => ({ selectedLogIds: new Set(), toggleSelectedLogId: (id: string) => set((state) => { diff --git a/app/src/state/sharedArgumentsEditor.slice.ts b/app/src/state/sharedArgumentsEditor.slice.ts new file mode 100644 index 0000000..9823eda --- /dev/null +++ b/app/src/state/sharedArgumentsEditor.slice.ts @@ -0,0 +1,33 @@ +import loader, { type Monaco } from "@monaco-editor/loader"; + +import { type SliceCreator } from "./store"; + +export const editorBackground = "#fafafa"; + +export type SharedArgumentsEditorSlice = { + monaco: null | Monaco; + loadMonaco: () => Promise; +}; + +export const createArgumentsEditorSlice: SliceCreator = (set, get) => ({ + monaco: loader.__getMonacoInstance(), + loadMonaco: async () => { + // We only want to run this client-side + if (typeof window === "undefined") return; + + const monaco = await loader.init(); + + monaco.editor.defineTheme("customTheme", { + base: "vs", + inherit: true, + rules: [], + colors: { + "editor.background": "#ffffff", + }, + }); + + set((state) => { + state.sharedArgumentsEditor.monaco = monaco; + }); + }, +}); diff --git a/app/src/state/store.ts b/app/src/state/store.ts index bbd3ce1..4ae3e27 100644 --- a/app/src/state/store.ts +++ b/app/src/state/store.ts @@ -7,9 +7,17 @@ import { type SharedVariantEditorSlice, createVariantEditorSlice, } from "./sharedVariantEditor.slice"; +import { + type SharedArgumentsEditorSlice, + createArgumentsEditorSlice, +} from "./sharedArgumentsEditor.slice"; import { type APIClient } from "~/utils/api"; import { type PersistedState, persistOptions } from "./persist"; import { type SelectedLogsSlice, createSelectedLogsSlice } from "./selectedLogsSlice"; +import { + type SelectedDatasetEntriesSlice, + createSelectedDatasetEntriesSlice, +} from "./selectedDatasetEntriesSlice"; import { type LogFiltersSlice, createLogFiltersSlice } from "./logFiltersSlice"; import { type ColumnVisibilitySlice, createColumnVisibilitySlice } from "./columnVisiblitySlice"; import { type FeatureFlagsSlice, createFeatureFlagsSlice } from "./featureFlags"; @@ -18,15 +26,14 @@ enableMapSet(); export type State = { isRehydrated: boolean; - drawerOpen: boolean; - openDrawer: () => void; - closeDrawer: () => void; api: APIClient | null; setApi: (api: APIClient) => void; sharedVariantEditor: SharedVariantEditorSlice; + sharedArgumentsEditor: SharedArgumentsEditorSlice; selectedProjectId: string | null; setSelectedProjectId: (id: string) => void; selectedLogs: SelectedLogsSlice; + selectedDatasetEntries: SelectedDatasetEntriesSlice; logFilters: LogFiltersSlice; columnVisibility: ColumnVisibilitySlice; featureFlags: FeatureFlagsSlice; @@ -46,22 +53,15 @@ const useBaseStore = create { state.api = api; }), - drawerOpen: false, - openDrawer: () => - set((state) => { - state.drawerOpen = true; - }), - closeDrawer: () => - set((state) => { - state.drawerOpen = false; - }), sharedVariantEditor: createVariantEditorSlice(set, get, ...rest), + sharedArgumentsEditor: createArgumentsEditorSlice(set, get, ...rest), selectedProjectId: null, setSelectedProjectId: (id: string) => set((state) => { state.selectedProjectId = id; }), selectedLogs: createSelectedLogsSlice(set, get, ...rest), + selectedDatasetEntries: createSelectedDatasetEntriesSlice(set, get, ...rest), logFilters: createLogFiltersSlice(set, get, ...rest), columnVisibility: createColumnVisibilitySlice(set, get, ...rest), featureFlags: createFeatureFlagsSlice(set, get, ...rest), diff --git a/app/src/utils/countTokens.ts b/app/src/utils/countTokens.ts index 653adaa..2a4ff14 100644 --- a/app/src/utils/countTokens.ts +++ b/app/src/utils/countTokens.ts @@ -12,6 +12,13 @@ export const countOpenAIChatTokens = ( model: SupportedModel, messages: ChatCompletion.Choice.Message[], ) => { - return new GPTTokens({ model, messages: messages as unknown as GPTTokensMessageItem[] }) - .usedTokens; + const reformattedMessages = messages.map((message) => ({ + role: message.role, + // Not completely accurate, but gives a rough idea of the token count + content: message.content ?? JSON.stringify(message.function_call), + })); + return new GPTTokens({ + model, + messages: reformattedMessages as unknown as GPTTokensMessageItem[], + }).usedTokens; }; diff --git a/app/src/utils/hooks.ts b/app/src/utils/hooks.ts index 9815f3f..bd51d8a 100644 --- a/app/src/utils/hooks.ts +++ b/app/src/utils/hooks.ts @@ -148,6 +148,49 @@ export const useScenarioVars = () => { ); }; +export const useDatasets = () => { + const selectedProjectId = useAppStore((state) => state.selectedProjectId); + return api.datasets.list.useQuery( + { projectId: selectedProjectId ?? "" }, + { enabled: !!selectedProjectId }, + ); +}; + +export const useDataset = () => { + const router = useRouter(); + const dataset = api.datasets.get.useQuery( + { id: router.query.id as string }, + { enabled: !!router.query.id }, + ); + + return dataset; +}; + +export const useDatasetEntries = () => { + const dataset = useDataset().data; + const { page, pageSize } = usePageParams(); + + const { data, isLoading, ...rest } = api.datasetEntries.list.useQuery( + { datasetId: dataset?.id ?? "", page, pageSize }, + { enabled: !!dataset?.id }, + ); + + const [stableData, setStableData] = useState(data); + + useEffect(() => { + // Prevent annoying flashes while logs are loading from the server + if (!isLoading) { + setStableData(data); + } + }, [data, isLoading]); + + return { data: stableData, isLoading, ...rest }; +}; + +export const useDatasetEntry = (entryId: string | null) => { + return api.datasetEntries.get.useQuery({ id: entryId as string }, { enabled: !!entryId }); +}; + export const useLoggedCalls = (applyFilters = true) => { const selectedProjectId = useAppStore((state) => state.selectedProjectId); const { page, pageSize } = usePageParams(); diff --git a/app/src/utils/utils.ts b/app/src/utils/utils.ts index a1432b9..6688013 100644 --- a/app/src/utils/utils.ts +++ b/app/src/utils/utils.ts @@ -10,3 +10,45 @@ export const lookupModel = (provider: string, model: string) => { export const modelLabel = (provider: string, model: string) => `${provider}/${lookupModel(provider, model)?.name ?? model}`; + +// Check if the str could be parsed to a message function call +export const parseableToFunctionCall = (str: string) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let parsedJSON: any; + try { + parsedJSON = JSON.parse(str); + } catch { + return false; + } + + // Check if the parsedJSON is an object and not null + if (typeof parsedJSON !== "object" || parsedJSON === null) { + return false; + } + + // Check if only the keys "name" and "arguments" exist + const keys = Object.keys(parsedJSON as Record); + if (keys.length !== 2 || !keys.includes("name") || !keys.includes("arguments")) { + return false; + } + + // Check if both "name" and "arguments" are of type string + if (typeof parsedJSON.name !== "string" || typeof parsedJSON.arguments !== "string") { + return false; + } + + // Check if the "arguments" value is parseable to an object + let parsedArguments: unknown; + try { + parsedArguments = JSON.parse(parsedJSON["arguments"]); + } catch { + return false; + } + + // Check if parsedArguments is an object and not null + if (typeof parsedArguments !== "object" || parsedArguments === null) { + return false; + } + + return true; +};