Export formatted logged calls (#187)

* Export formatted data

* Properly update inputMessageHashMap

* Hide remove duplicates checkbox in advanced options

* Remove unused import
This commit is contained in:
arcticfly
2023-08-23 20:45:27 -07:00
committed by GitHub
parent a4131e4a10
commit 9632ccbc71
5 changed files with 508 additions and 0 deletions

View File

@@ -48,6 +48,7 @@
"@trpc/react-query": "^10.26.0",
"@trpc/server": "^10.26.0",
"@vercel/og": "^0.5.9",
"archiver": "^6.0.0",
"ast-types": "^0.14.2",
"chroma-js": "^2.4.2",
"concurrently": "^8.2.0",
@@ -99,6 +100,7 @@
"replicate": "^0.12.3",
"socket.io": "^4.7.1",
"socket.io-client": "^4.7.1",
"stream-buffers": "^3.0.2",
"superjson": "1.12.2",
"trpc-openapi": "^1.2.0",
"tsx": "^3.12.7",
@@ -111,6 +113,7 @@
},
"devDependencies": {
"@openapi-contrib/openapi-schema-to-json-schema": "^4.0.5",
"@types/archiver": "^5.3.2",
"@types/babel__core": "^7.20.1",
"@types/babel__standalone": "^7.1.4",
"@types/chroma-js": "^2.4.0",
@@ -127,6 +130,7 @@
"@types/react": "^18.2.6",
"@types/react-dom": "^18.2.4",
"@types/react-syntax-highlighter": "^15.5.7",
"@types/stream-buffers": "^3.0.4",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.59.6",
"@typescript-eslint/parser": "^5.59.6",

View File

@@ -0,0 +1,197 @@
import { useState, useEffect } from "react";
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody,
ModalFooter,
HStack,
VStack,
Icon,
Text,
Button,
Checkbox,
NumberInput,
NumberInputField,
NumberInputStepper,
NumberIncrementStepper,
NumberDecrementStepper,
Collapse,
useDisclosure,
type UseDisclosureReturn,
} from "@chakra-ui/react";
import { BiExport } from "react-icons/bi";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { api } from "~/utils/api";
import { useAppStore } from "~/state/store";
import ActionButton from "./ActionButton";
import InputDropdown from "../InputDropdown";
import { FiChevronUp, FiChevronDown } from "react-icons/fi";
// Export formats offered in the modal's "Format" dropdown. The first entry is
// the default selection and is what the modal resets to each time it opens.
// NOTE(review): the chosen format is sent to the `loggedCalls.export` mutation
// as `selectedExportFormat`; confirm the server actually branches on it.
const SUPPORTED_EXPORT_FORMATS = ["alpaca-finetune", "openai-fine-tune", "unformatted"];
/**
 * Toolbar action that opens the "Export Logs" modal.
 *
 * The button is disabled while no logs are selected; the modal itself is
 * always rendered and driven by the shared disclosure state.
 */
const ExportButton = () => {
  const exportModal = useDisclosure();
  const selectedIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
  const nothingSelected = selectedIds.size === 0;

  return (
    <>
      <ActionButton
        label="Export"
        icon={BiExport}
        onClick={exportModal.onOpen}
        isDisabled={nothingSelected}
      />
      <ExportLogsModal disclosure={exportModal} />
    </>
  );
};

export default ExportButton;
/**
 * Modal that lets the user export the currently selected logged calls as a
 * zip archive.
 *
 * The user picks an export format, a testing-split percentage and (under
 * "Advanced Options") whether duplicates are removed. On "Export" the
 * `loggedCalls.export` mutation is called; its base64 response is turned into
 * a Blob and downloaded as `data.zip`, after which the modal closes and the
 * log selection is cleared.
 *
 * @param disclosure Chakra disclosure controlling the modal's open state.
 */
const ExportLogsModal = ({ disclosure }: { disclosure: UseDisclosureReturn }) => {
  const selectedProjectId = useAppStore((s) => s.selectedProjectId);
  const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
  const clearSelectedLogIds = useAppStore((s) => s.selectedLogs.clearSelectedLogIds);

  const [selectedExportFormat, setSelectedExportFormat] = useState(SUPPORTED_EXPORT_FORMATS[0]);
  const [testingSplit, setTestingSplit] = useState(10);
  const [removeDuplicates, setRemoveDuplicates] = useState(true);
  const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);

  // Restore the default options every time the modal is (re)opened.
  useEffect(() => {
    if (disclosure.isOpen) {
      setSelectedExportFormat(SUPPORTED_EXPORT_FORMATS[0]);
      setTestingSplit(10);
      setRemoveDuplicates(true);
    }
  }, [disclosure.isOpen]);

  const exportLogsMutation = api.loggedCalls.export.useMutation();

  const [exportLogs, exportInProgress] = useHandledAsyncCallback(async () => {
    // `!testingSplit` also rejects the NaN that NumberInput reports for an empty field.
    if (!selectedProjectId || !selectedLogIds.size || !testingSplit || !selectedExportFormat)
      return;

    const response = await exportLogsMutation.mutateAsync({
      projectId: selectedProjectId,
      selectedLogIds: Array.from(selectedLogIds),
      testingSplit,
      selectedExportFormat,
      removeDuplicates,
    });

    // The server returns a base64-encoded zip archive, so label it as such;
    // the Blob inherits its type from the data URL's media type.
    const dataUrl = `data:application/zip;base64,${response}`;
    const blob = await fetch(dataUrl).then((res) => res.blob());
    const url = URL.createObjectURL(blob);

    // Trigger the download through a temporary, detached-after-use anchor.
    const a = document.createElement("a");
    a.href = url;
    a.download = `data.zip`;
    document.body.appendChild(a);
    try {
      a.click();
    } finally {
      document.body.removeChild(a);
      // Release the blob once the browser has had a chance to start the
      // download; otherwise the object URL keeps the data alive until unload.
      setTimeout(() => URL.revokeObjectURL(url), 1000);
    }

    disclosure.onClose();
    clearSelectedLogIds();
  }, [
    exportLogsMutation,
    selectedProjectId,
    selectedLogIds,
    testingSplit,
    selectedExportFormat,
    removeDuplicates,
  ]);

  return (
    <Modal size={{ base: "xl", md: "2xl" }} {...disclosure}>
      <ModalOverlay />
      <ModalContent w={1200}>
        <ModalHeader>
          <HStack>
            <Icon as={BiExport} />
            <Text>Export Logs</Text>
          </HStack>
        </ModalHeader>
        <ModalCloseButton />
        <ModalBody maxW="unset">
          <VStack w="full" spacing={8} pt={4} alignItems="flex-start">
            <Text>
              We'll export the <b>{selectedLogIds.size}</b> logs you have selected in the format of
              your choice.
            </Text>
            <VStack alignItems="flex-start" spacing={4}>
              <HStack spacing={2}>
                <Text fontWeight="bold" w={48}>
                  Format:
                </Text>
                <InputDropdown
                  options={SUPPORTED_EXPORT_FORMATS}
                  selectedOption={selectedExportFormat}
                  onSelect={(option) => setSelectedExportFormat(option)}
                  inputGroupProps={{ w: 48 }}
                />
              </HStack>
              <HStack spacing={2}>
                <Text fontWeight="bold" w={48}>
                  Testing Split:
                </Text>
                <HStack>
                  <NumberInput
                    defaultValue={10}
                    onChange={(_, num) => setTestingSplit(num)}
                    min={1}
                    max={100}
                    w={48}
                  >
                    <NumberInputField />
                    <NumberInputStepper>
                      <NumberIncrementStepper />
                      <NumberDecrementStepper />
                    </NumberInputStepper>
                  </NumberInput>
                </HStack>
              </HStack>
            </VStack>
            <VStack alignItems="flex-start" spacing={0}>
              <Button
                variant="unstyled"
                color="blue.600"
                onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
              >
                <HStack>
                  <Text>Advanced Options</Text>
                  <Icon as={showAdvancedOptions ? FiChevronUp : FiChevronDown} />
                </HStack>
              </Button>
              <Collapse in={showAdvancedOptions} unmountOnExit={true}>
                <VStack align="stretch" pt={4}>
                  <Checkbox
                    colorScheme="blue"
                    isChecked={removeDuplicates}
                    onChange={(e) => setRemoveDuplicates(e.target.checked)}
                  >
                    <Text>Remove duplicates? (recommended)</Text>
                  </Checkbox>
                </VStack>
              </Collapse>
            </VStack>
          </VStack>
        </ModalBody>
        <ModalFooter>
          <HStack>
            <Button colorScheme="gray" onClick={disclosure.onClose} minW={24}>
              Cancel
            </Button>
            <Button colorScheme="blue" onClick={exportLogs} isLoading={exportInProgress} minW={24}>
              Export
            </Button>
          </HStack>
        </ModalFooter>
      </ModalContent>
    </Modal>
  );
};

View File

@@ -11,6 +11,7 @@ import { FiFilter } from "react-icons/fi";
import LogFilters from "~/components/requestLogs/LogFilters/LogFilters";
import ColumnVisiblityDropdown from "~/components/requestLogs/ColumnVisiblityDropdown";
import FineTuneButton from "~/components/requestLogs/FineTuneButton";
import ExportButton from "~/components/requestLogs/ExportButton";
export default function LoggedCalls() {
const selectedLogIds = useAppStore((s) => s.selectedLogs.selectedLogIds);
@@ -35,6 +36,7 @@ export default function LoggedCalls() {
icon={RiFlaskLine}
isDisabled={selectedLogIds.size === 0}
/>
<ExportButton />
<ColumnVisiblityDropdown />
<ActionButton
onClick={() => {

View File

@@ -1,11 +1,16 @@
import { z } from "zod";
import { type Expression, type SqlBool, sql, type RawBuilder } from "kysely";
import { jsonArrayFrom } from "kysely/helpers/postgres";
import archiver from "archiver";
import { WritableStreamBuffer } from "stream-buffers";
import { type JsonValue } from "type-fest";
import { shuffle } from "lodash-es";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { kysely, prisma } from "~/server/db";
import { comparators, defaultFilterableFields } from "~/state/logFiltersSlice";
import { requireCanViewProject } from "~/utils/accessControl";
import hashObject from "~/server/utils/hashObject";
// create comparator type based off of comparators
const comparatorToSqlExpression = (comparator: (typeof comparators)[number], value: string) => {
@@ -180,4 +185,101 @@ export const loggedCallsRouter = createTRPCRouter({
return tags.map((tag) => tag.name);
}),
export: protectedProcedure
.input(
z.object({
projectId: z.string(),
selectedLogIds: z.string().array(),
testingSplit: z.number(),
selectedExportFormat: z.string(),
removeDuplicates: z.boolean(),
}),
)
.mutation(async ({ input, ctx }) => {
await requireCanViewProject(input.projectId, ctx);
// Fetch the real data using Prisma
const loggedCallsFromDb = await ctx.prisma.loggedCallModelResponse.findMany({
where: {
originalLoggedCall: {
projectId: input.projectId,
id: { in: input.selectedLogIds },
},
},
});
// Convert the database data into the desired format
let formattedLoggedCalls: { input: JsonValue[]; output: JsonValue }[] = loggedCallsFromDb.map(
(call) => ({
input: (call.reqPayload as unknown as Record<string, unknown>).messages as JsonValue[],
output: (call.respPayload as unknown as { choices: { message: unknown }[] }).choices[0]
?.message as JsonValue,
}),
);
if (input.removeDuplicates) {
const deduplicatedLoggedCalls = [];
const loggedCallHashSet = new Set<string>();
for (const loggedCall of formattedLoggedCalls) {
const loggedCallHash = hashObject(loggedCall);
if (!loggedCallHashSet.has(loggedCallHash)) {
loggedCallHashSet.add(loggedCallHash);
deduplicatedLoggedCalls.push(loggedCall);
}
}
formattedLoggedCalls = deduplicatedLoggedCalls;
}
// Remove duplicate messages from input
const inputMessageHashMap = new Map<string, number>();
for (const loggedCall of formattedLoggedCalls) {
for (const message of loggedCall.input) {
const hash = hashObject(message);
if (inputMessageHashMap.has(hash)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
inputMessageHashMap.set(hash, inputMessageHashMap.get(hash)! + 1);
} else {
inputMessageHashMap.set(hash, 0);
}
}
}
for (const loggedCall of formattedLoggedCalls) {
loggedCall.input = loggedCall.input.filter((message) => {
const hash = hashObject(message);
// If the same message appears in a single input multiple times, there is some danger of
// it being removed from all logged calls. This is enough of an edge case that we don't
// need to worry about it for now.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return inputMessageHashMap.get(hash)! < formattedLoggedCalls.length;
});
}
// Stringify inputs and outputs
const stringifiedLoggedCalls = shuffle(formattedLoggedCalls).map((loggedCall) => ({
input: JSON.stringify(loggedCall.input),
output: JSON.stringify(loggedCall.output),
}));
const splitIndex = Math.floor((stringifiedLoggedCalls.length * input.testingSplit) / 100);
const testingData = stringifiedLoggedCalls.slice(0, splitIndex);
const trainingData = stringifiedLoggedCalls.slice(splitIndex);
// Convert arrays to JSONL format
const trainingDataJSONL = trainingData.map((item) => JSON.stringify(item)).join("\n");
const testingDataJSONL = testingData.map((item) => JSON.stringify(item)).join("\n");
const output = new WritableStreamBuffer();
const archive = archiver("zip");
archive.pipe(output);
archive.append(trainingDataJSONL, { name: "train.jsonl" });
archive.append(testingDataJSONL, { name: "test.jsonl" });
await archive.finalize();
// Convert buffer to base64
const base64 = output.getContents().toString("base64");
return base64;
}),
});