Compare commits

..

13 Commits

Author SHA1 Message Date
David Corbitt
a1e3036064 Remove unused import 2023-08-18 01:35:01 -07:00
David Corbitt
22a1423690 Square header border when scrolled down 2023-08-18 01:33:13 -07:00
arcticfly
733d53625b Add Gryphe/MythoMax-L2-13b (#173) 2023-08-18 00:37:16 -07:00
arcticfly
a5e59e4235 Allow user to delete scenario without variables (#172)
* Allow user to delete scenario without variables

* Hide expand button for empty scenario editor

* Add header to scenario modal
2023-08-18 00:08:32 -07:00
Kyle Corbitt
d0102e3202 Merge pull request #171 from OpenPipe/experiment-slug
Use shorter experiment IDs
2023-08-17 23:33:30 -07:00
Kyle Corbitt
bd571c4c4e Merge pull request #170 from OpenPipe/jobs-log
Enqueue tasks more efficiently
2023-08-17 23:33:20 -07:00
Kyle Corbitt
296eb23d97 Use shorter experiment IDs
Because https://app.openpipe.ai/experiments/B1EtN6oHeXMele2 is a cooler URL than https://app.openpipe.ai/experiments/3692942c-6f1b-4bef-83b1-c11f00a3fbdd
2023-08-17 23:28:56 -07:00
Kyle Corbitt
4e2ae7a441 Enqueue tasks more efficiently
Previously we were opening a new database connection for each task we added. Not a problem at small scale but kinda overwhelming for Postgres now that we have more usage.
2023-08-17 22:42:46 -07:00
Kyle Corbitt
072dcee376 Merge pull request #168 from OpenPipe/jobs-log
Admin dashboard for jobs
2023-08-17 22:26:10 -07:00
Kyle Corbitt
94464c0617 Admin dashboard for jobs
Extremely simple jobs dashboard to sanity-check what we've got going on in the job queue.
2023-08-17 22:20:39 -07:00
arcticfly
980644f13c Support vicuna system message (#167)
* Support vicuna system message

* Change tags to USER and ASSISTANT
2023-08-17 21:02:27 -07:00
arcticfly
6a56250001 Add platypus 13b, vicuna 13b, and nous hermes 7b (#166)
* Add platypus

* Add vicuna 13b and nous hermes 7b
2023-08-17 20:01:10 -07:00
Kyle Corbitt
b1c7bbbd4a Merge pull request #165 from OpenPipe/better-output
Don't define CellWrapper inline
2023-08-17 19:07:32 -07:00
36 changed files with 912 additions and 214 deletions

View File

@@ -12,6 +12,7 @@ declare module "nextjs-routes" {
export type Route =
| StaticRoute<"/account/signin">
| StaticRoute<"/admin/jobs">
| DynamicRoute<"/api/auth/[...nextauth]", { "nextauth": string[] }>
| StaticRoute<"/api/experiments/og-image">
| DynamicRoute<"/api/trpc/[trpc]", { "trpc": string }>
@@ -20,7 +21,7 @@ declare module "nextjs-routes" {
| StaticRoute<"/dashboard">
| DynamicRoute<"/data/[id]", { "id": string }>
| StaticRoute<"/data">
| DynamicRoute<"/experiments/[id]", { "id": string }>
| DynamicRoute<"/experiments/[experimentSlug]", { "experimentSlug": string }>
| StaticRoute<"/experiments">
| StaticRoute<"/">
| DynamicRoute<"/invitations/[invitationToken]", { "invitationToken": string }>

View File

@@ -18,6 +18,7 @@
"lint": "next lint",
"start": "TZ=UTC next start",
"codegen:clients": "tsx src/server/scripts/client-codegen.ts",
"codegen:db": "prisma generate && kysely-codegen --dialect postgres --out-file src/server/db.types.ts",
"seed": "tsx prisma/seed.ts",
"check": "concurrently 'pnpm lint' 'pnpm tsc' 'pnpm prettier . --check'",
"test": "pnpm vitest"
@@ -65,6 +66,7 @@
"json-stringify-pretty-compact": "^4.0.0",
"jsonschema": "^1.4.1",
"kysely": "^0.26.1",
"kysely-codegen": "^0.10.1",
"lodash-es": "^4.17.21",
"lucide-react": "^0.265.0",
"marked": "^7.0.3",

View File

@@ -0,0 +1,88 @@
/*
* Copyright 2023 Viascom Ltd liab. Co
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
CREATE EXTENSION IF NOT EXISTS pgcrypto;
CREATE OR REPLACE FUNCTION nanoid(
size int DEFAULT 21,
alphabet text DEFAULT '_-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
)
RETURNS text
LANGUAGE plpgsql
volatile
AS
$$
DECLARE
idBuilder text := '';
counter int := 0;
bytes bytea;
alphabetIndex int;
alphabetArray text[];
alphabetLength int;
mask int;
step int;
BEGIN
alphabetArray := regexp_split_to_array(alphabet, '');
alphabetLength := array_length(alphabetArray, 1);
mask := (2 << cast(floor(log(alphabetLength - 1) / log(2)) as int)) - 1;
step := cast(ceil(1.6 * mask * size / alphabetLength) AS int);
while true
loop
bytes := gen_random_bytes(step);
while counter < step
loop
alphabetIndex := (get_byte(bytes, counter) & mask) + 1;
if alphabetIndex <= alphabetLength then
idBuilder := idBuilder || alphabetArray[alphabetIndex];
if length(idBuilder) = size then
return idBuilder;
end if;
end if;
counter := counter + 1;
end loop;
counter := 0;
end loop;
END
$$;
-- Make a short_nanoid function that uses an alphanumeric alphabet (no _ or -) and a length of 15
CREATE OR REPLACE FUNCTION short_nanoid()
RETURNS text
LANGUAGE plpgsql
volatile
AS
$$
BEGIN
RETURN nanoid(15, '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ');
END
$$;
-- AlterTable
ALTER TABLE "Experiment" ADD COLUMN "slug" TEXT NOT NULL DEFAULT short_nanoid();
-- For existing experiments, keep the existing id as the slug for backwards compatibility
UPDATE "Experiment" SET "slug" = "id";
-- CreateIndex
CREATE UNIQUE INDEX "Experiment_slug_key" ON "Experiment"("slug");
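With this migration in place, new experiments get a database-generated slug automatically. A minimal sketch (not part of this change set) of what that looks like from application code, assuming the Prisma client generated from the updated schema below; the project id is a placeholder:

import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

async function main() {
  // The database fills in `slug` via short_nanoid(); the application never generates it.
  const experiment = await prisma.experiment.create({
    data: { label: "My experiment", projectId: "<existing-project-uuid>" },
  });
  console.log(experiment.slug); // e.g. "B1EtN6oHeXMele2" — 15 alphanumeric characters
}

void main();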

View File

@@ -11,7 +11,9 @@ datasource db {
}
model Experiment {
id String @id @default(uuid()) @db.Uuid
id String @id @default(uuid()) @db.Uuid
slug String @unique @default(dbgenerated("short_nanoid()"))
label String
sortIndex Int @default(0)
@@ -207,14 +209,14 @@ model Project {
personalProjectUserId String? @unique @db.Uuid
personalProjectUser User? @relation(fields: [personalProjectUserId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
projectUsers ProjectUser[]
projectUserInvitations UserInvitation[]
experiments Experiment[]
datasets Dataset[]
loggedCalls LoggedCall[]
apiKeys ApiKey[]
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
projectUsers ProjectUser[]
projectUserInvitations UserInvitation[]
experiments Experiment[]
datasets Dataset[]
loggedCalls LoggedCall[]
apiKeys ApiKey[]
}
enum ProjectUserRole {
@@ -324,10 +326,10 @@ model LoggedCallModelResponse {
}
model LoggedCallTag {
id String @id @default(uuid()) @db.Uuid
name String
value String?
projectId String @db.Uuid
id String @id @default(uuid()) @db.Uuid
name String
value String?
projectId String @db.Uuid
loggedCallId String @db.Uuid
loggedCall LoggedCall @relation(fields: [loggedCallId], references: [id], onDelete: Cascade)
@@ -391,12 +393,12 @@ model User {
role UserRole @default(USER)
accounts Account[]
sessions Session[]
projectUsers ProjectUser[]
projects Project[]
worldChampEntrant WorldChampEntrant?
sentUserInvitations UserInvitation[]
accounts Account[]
sessions Session[]
projectUsers ProjectUser[]
projects Project[]
worldChampEntrant WorldChampEntrant?
sentUserInvitations UserInvitation[]
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@@ -405,17 +407,17 @@ model User {
model UserInvitation {
id String @id @default(uuid()) @db.Uuid
projectId String @db.Uuid
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
email String
role ProjectUserRole
invitationToken String @unique
senderId String @db.Uuid
sender User @relation(fields: [senderId], references: [id], onDelete: Cascade)
projectId String @db.Uuid
project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
email String
role ProjectUserRole
invitationToken String @unique
senderId String @db.Uuid
sender User @relation(fields: [senderId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@unique([projectId, email])
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
}
model VerificationToken {

View File

@@ -10,6 +10,14 @@ await prisma.project.deleteMany({
where: { id: defaultId },
});
// Mark all users as admins
await prisma.user.updateMany({
where: {},
data: {
role: "ADMIN",
},
});
// If there's an existing project, just seed into it
const project =
(await prisma.project.findFirst({})) ??
@@ -18,12 +26,16 @@ const project =
}));
if (env.OPENPIPE_API_KEY) {
await prisma.apiKey.create({
data: {
await prisma.apiKey.upsert({
where: {
apiKey: env.OPENPIPE_API_KEY,
},
create: {
projectId: project.id,
name: "Default API Key",
apiKey: env.OPENPIPE_API_KEY,
},
update: {},
});
}

View File

@@ -8,7 +8,7 @@ import {
useHandledAsyncCallback,
useVisibleScenarioIds,
} from "~/utils/hooks";
import { cellPadding } from "../constants";
import { cellPadding } from "./constants";
import { ActionButton } from "./ScenariosHeader";
export default function AddVariantButton() {

View File

@@ -16,7 +16,7 @@ import {
VStack,
} from "@chakra-ui/react";
import { BsArrowsAngleExpand, BsX } from "react-icons/bs";
import { cellPadding } from "../constants";
import { cellPadding } from "./constants";
import { FloatingLabelInput } from "./FloatingLabelInput";
import { ScenarioEditorModal } from "./ScenarioEditorModal";
@@ -111,25 +111,23 @@ export default function ScenarioEditor({
onDrop={onReorder}
backgroundColor={isDragTarget ? "gray.100" : "transparent"}
>
{variableLabels.length === 0 ? (
<Box color="gray.500">
{vars.data ? "No scenario variables configured" : "Loading..."}
</Box>
) : (
{
<VStack spacing={4} flex={1} py={2}>
<HStack justifyContent="space-between" w="100%" align="center" spacing={0}>
<Text flex={1}>Scenario</Text>
<Tooltip label="Expand" hasArrow>
<IconButton
aria-label="Expand"
icon={<Icon as={BsArrowsAngleExpand} boxSize={3} />}
onClick={() => setScenarioEditorModalOpen(true)}
size="xs"
colorScheme="gray"
color="gray.500"
variant="ghost"
/>
</Tooltip>
{variableLabels.length && (
<Tooltip label="Expand" hasArrow>
<IconButton
aria-label="Expand"
icon={<Icon as={BsArrowsAngleExpand} boxSize={3} />}
onClick={() => setScenarioEditorModalOpen(true)}
size="xs"
colorScheme="gray"
color="gray.500"
variant="ghost"
/>
</Tooltip>
)}
{canModify && props.canHide && (
<Tooltip label="Delete" hasArrow>
<IconButton
@@ -150,31 +148,38 @@ export default function ScenarioEditor({
</Tooltip>
)}
</HStack>
{variableLabels.map((key) => {
const value = values[key] ?? "";
return (
<FloatingLabelInput
key={key}
label={key}
isDisabled={!canModify}
style={{ width: "100%" }}
maxHeight={32}
value={value}
onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value }));
}}
onKeyDown={(e) => {
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
e.currentTarget.blur();
onSave();
}
}}
onMouseEnter={() => setVariableInputHovered(true)}
onMouseLeave={() => setVariableInputHovered(false)}
/>
);
})}
{variableLabels.length === 0 ? (
<Box color="gray.500">
{vars.data ? "No scenario variables configured" : "Loading..."}
</Box>
) : (
variableLabels.map((key) => {
const value = values[key] ?? "";
return (
<FloatingLabelInput
key={key}
label={key}
isDisabled={!canModify}
style={{ width: "100%" }}
maxHeight={32}
value={value}
onChange={(e) => {
setValues((prev) => ({ ...prev, [key]: e.target.value }));
}}
onKeyDown={(e) => {
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
e.currentTarget.blur();
onSave();
}
}}
onMouseEnter={() => setVariableInputHovered(true)}
onMouseLeave={() => setVariableInputHovered(false)}
/>
);
})
)}
{hasChanged && (
<HStack justify="right">
<Button
@@ -192,7 +197,7 @@ export default function ScenarioEditor({
</HStack>
)}
</VStack>
)}
}
</HStack>
{scenarioEditorModalOpen && (
<ScenarioEditorModal

View File

@@ -65,11 +65,11 @@ export const ScenarioEditorModal = ({
<Modal
isOpen
onClose={onClose}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "5xl", xl: "7xl" }}
size={{ base: "xl", sm: "2xl", md: "3xl", lg: "4xl", xl: "5xl" }}
>
<ModalOverlay />
<ModalContent w={1200}>
<ModalHeader />
<ModalHeader>Edit Scenario</ModalHeader>
<ModalCloseButton />
<ModalBody maxW="unset">
<VStack spacing={8}>

View File

@@ -11,7 +11,7 @@ import {
IconButton,
Spinner,
} from "@chakra-ui/react";
import { cellPadding } from "../constants";
import { cellPadding } from "./constants";
import {
useExperiment,
useExperimentAccess,

View File

@@ -1,11 +1,11 @@
import { useState, type DragEvent } from "react";
import { type PromptVariant } from "../OutputsTable/types";
import { type PromptVariant } from "../types";
import { api } from "~/utils/api";
import { RiDraggable } from "react-icons/ri";
import { useExperimentAccess, useHandledAsyncCallback } from "~/utils/hooks";
import { HStack, Icon, Text, GridItem, type GridItemProps } from "@chakra-ui/react"; // Changed here
import { cellPadding, headerMinHeight } from "../constants";
import AutoResizeTextArea from "../AutoResizeTextArea";
import AutoResizeTextArea from "../../AutoResizeTextArea";
import VariantHeaderMenuButton from "./VariantHeaderMenuButton";
export default function VariantHeader(
@@ -75,7 +75,7 @@ export default function VariantHeader(
padding={0}
sx={{
position: "sticky",
top: "-2",
top: "0",
// Ensure that the menu always appears above the sticky header of other variants
zIndex: menuOpen ? "dropdown" : 10,
}}

View File

@@ -1,6 +1,4 @@
import { type PromptVariant } from "../OutputsTable/types";
import { api } from "~/utils/api";
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { useState } from "react";
import {
Icon,
Menu,
@@ -14,10 +12,13 @@ import {
} from "@chakra-ui/react";
import { BsFillTrashFill, BsGear, BsStars } from "react-icons/bs";
import { FaRegClone } from "react-icons/fa";
import { useState } from "react";
import { RefinePromptModal } from "../RefinePromptModal/RefinePromptModal";
import { RiExchangeFundsFill } from "react-icons/ri";
import { ChangeModelModal } from "../ChangeModelModal/ChangeModelModal";
import { api } from "~/utils/api";
import { useHandledAsyncCallback, useVisibleScenarioIds } from "~/utils/hooks";
import { type PromptVariant } from "../types";
import { RefinePromptModal } from "../../RefinePromptModal/RefinePromptModal";
import { ChangeModelModal } from "../../ChangeModelModal/ChangeModelModal";
export default function VariantHeaderMenuButton({
variant,

View File

@@ -1,6 +1,6 @@
import { HStack, Icon, Text, useToken } from "@chakra-ui/react";
import { type PromptVariant } from "./types";
import { cellPadding } from "../constants";
import { cellPadding } from "./constants";
import { api } from "~/utils/api";
import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";

View File

@@ -3,13 +3,14 @@ import { api } from "~/utils/api";
import AddVariantButton from "./AddVariantButton";
import ScenarioRow from "./ScenarioRow";
import VariantEditor from "./VariantEditor";
import VariantHeader from "../VariantHeader/VariantHeader";
import VariantHeader from "./VariantHeader/VariantHeader";
import VariantStats from "./VariantStats";
import { ScenariosHeader } from "./ScenariosHeader";
import { borders } from "./styles";
import { useScenarios } from "~/utils/hooks";
import ScenarioPaginator from "./ScenarioPaginator";
import { Fragment } from "react";
import useScrolledPast from "./useHasScrolledPast";
export default function OutputsTable({ experimentId }: { experimentId: string | undefined }) {
const variants = api.promptVariants.list.useQuery(
@@ -18,6 +19,7 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
);
const scenarios = useScenarios();
const shouldFlattenHeader = useScrolledPast(50);
if (!variants.data || !scenarios.data) return null;
@@ -63,8 +65,8 @@ export default function OutputsTable({ experimentId }: { experimentId: string |
variant={variant}
canHide={variants.data.length > 1}
rowStart={1}
borderTopLeftRadius={isFirst ? 8 : 0}
borderTopRightRadius={isLast ? 8 : 0}
borderTopLeftRadius={isFirst && !shouldFlattenHeader ? 8 : 0}
borderTopRightRadius={isLast && !shouldFlattenHeader ? 8 : 0}
{...sharedProps}
/>
<GridItem rowStart={2} {...sharedProps}>

View File

@@ -0,0 +1,34 @@
import { useState, useEffect } from "react";
const useScrolledPast = (scrollThreshold: number) => {
const [hasScrolledPast, setHasScrolledPast] = useState(true);
useEffect(() => {
const container = document.getElementById("output-container");
if (!container) {
console.warn('Element with id "output-container" not found.');
return;
}
const checkScroll = () => {
const { scrollTop } = container;
// Flatten once scrollTop exceeds scrollThreshold
setHasScrolledPast(scrollTop > scrollThreshold);
};
checkScroll();
container.addEventListener("scroll", checkScroll);
// Cleanup
return () => {
container.removeEventListener("scroll", checkScroll);
};
}, []);
return hasScrolledPast;
};
export default useScrolledPast;
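The hook pairs with the id="output-container" added to the scroll container in the experiment page further down. A minimal sketch of the wiring, assuming Chakra UI and the hook above:

import { Box } from "@chakra-ui/react";
import useScrolledPast from "./useHasScrolledPast";

function VariantHeaderExample() {
  // Square off the header corners once the container has scrolled past 50px.
  const flattened = useScrolledPast(50);
  return <Box borderTopRadius={flattened ? 0 : 8}>Variant</Box>;
}

export default function PageExample() {
  // The hook looks this element up by id, so it must exist when the effect runs.
  return (
    <Box id="output-container" overflowY="auto" h="100vh">
      <VariantHeaderExample />
    </Box>
  );
}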

View File

@@ -14,21 +14,11 @@ import { formatTimePast } from "~/utils/dayjs";
import Link from "next/link";
import { useRouter } from "next/router";
import { BsPlusSquare } from "react-icons/bs";
import { api } from "~/utils/api";
import { RouterOutputs, api } from "~/utils/api";
import { useHandledAsyncCallback } from "~/utils/hooks";
import { useAppStore } from "~/state/store";
type ExperimentData = {
testScenarioCount: number;
promptVariantCount: number;
id: string;
label: string;
sortIndex: number;
createdAt: Date;
updatedAt: Date;
};
export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
export const ExperimentCard = ({ exp }: { exp: RouterOutputs["experiments"]["list"][0] }) => {
return (
<Card
w="full"
@@ -45,7 +35,7 @@ export const ExperimentCard = ({ exp }: { exp: ExperimentData }) => {
as={Link}
w="full"
h="full"
href={{ pathname: "/experiments/[id]", query: { id: exp.id } }}
href={{ pathname: "/experiments/[experimentSlug]", query: { experimentSlug: exp.slug } }}
justify="space-between"
>
<HStack w="full" color="gray.700" justify="center">
@@ -89,8 +79,8 @@ export const NewExperimentCard = () => {
projectId: selectedProjectId ?? "",
});
await router.push({
pathname: "/experiments/[id]",
query: { id: newExperiment.id },
pathname: "/experiments/[experimentSlug]",
query: { experimentSlug: newExperiment.slug },
});
}, [createMutation, router, selectedProjectId]);

View File

@@ -16,11 +16,14 @@ export const useOnForkButtonPressed = () => {
const [onFork, isForking] = useHandledAsyncCallback(async () => {
if (!experiment.data?.id || !selectedProjectId) return;
const forkedExperimentId = await forkMutation.mutateAsync({
const newExperiment = await forkMutation.mutateAsync({
id: experiment.data.id,
projectId: selectedProjectId,
});
await router.push({ pathname: "/experiments/[id]", query: { id: forkedExperimentId } });
await router.push({
pathname: "/experiments/[experimentSlug]",
query: { experimentSlug: newExperiment.slug },
});
}, [forkMutation, experiment.data?.id, router]);
const onForkButtonPressed = useCallback(() => {

View File

@@ -67,7 +67,13 @@ export default function ProjectMenu() {
);
return (
<VStack w="full" alignItems="flex-start" spacing={0} py={1}>
<VStack
w="full"
alignItems="flex-start"
spacing={0}
py={1}
zIndex={popover.isOpen ? "dropdown" : undefined}
>
<Popover
placement="bottom"
isOpen={popover.isOpen}

View File

@@ -12,7 +12,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
@@ -29,7 +28,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
@@ -126,7 +124,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
@@ -143,7 +140,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", {
model: "gpt-4",
stream: true,
messages: [
{
role: "system",
@@ -237,7 +233,6 @@ export const refinementActions: Record<string, RefinementAction> = {
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo",
stream: true,
messages: [
{
role: "system",

View File

@@ -3,10 +3,12 @@ import { type FrontendModelProvider } from "../types";
import { refinementActions } from "./refinementActions";
import {
templateOpenOrcaPrompt,
// templateAlpacaInstructPrompt,
templateAlpacaInstructPrompt,
// templateSystemUserAssistantPrompt,
templateInstructionInputResponsePrompt,
templateAiroborosPrompt,
templateGryphePrompt,
templateVicunaPrompt,
} from "./templatePrompt";
const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatOutput> = {
@@ -22,15 +24,16 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatO
learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
templatePrompt: templateOpenOrcaPrompt,
},
// "Open-Orca/OpenOrca-Platypus2-13B": {
// name: "OpenOrca-Platypus2-13B",
// contextWindow: 4096,
// pricePerSecond: 0.0003,
// speed: "medium",
// provider: "openpipe/Chat",
// learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrca-Platypus2-13B",
// templatePrompt: templateAlpacaInstructPrompt,
// },
"Open-Orca/OpenOrca-Platypus2-13B": {
name: "OpenOrca-Platypus2-13B",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/Open-Orca/OpenOrca-Platypus2-13B",
templatePrompt: templateAlpacaInstructPrompt,
defaultStopTokens: ["</s>"],
},
// "stabilityai/StableBeluga-13B": {
// name: "StableBeluga-13B",
// contextWindow: 4096,
@@ -58,6 +61,33 @@ const frontendModelProvider: FrontendModelProvider<SupportedModel, OpenpipeChatO
learnMoreUrl: "https://huggingface.co/jondurbin/airoboros-l2-13b-gpt4-2.0",
templatePrompt: templateAiroborosPrompt,
},
"lmsys/vicuna-13b-v1.5": {
name: "vicuna-13b-v1.5",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/lmsys/vicuna-13b-v1.5",
templatePrompt: templateVicunaPrompt,
},
"Gryphe/MythoMax-L2-13b": {
name: "MythoMax-L2-13b",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/Gryphe/MythoMax-L2-13b",
templatePrompt: templateGryphePrompt,
},
"NousResearch/Nous-Hermes-llama-2-7b": {
name: "Nous-Hermes-llama-2-7b",
contextWindow: 4096,
pricePerSecond: 0.0003,
speed: "medium",
provider: "openpipe/Chat",
learnMoreUrl: "https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b",
templatePrompt: templateInstructionInputResponsePrompt,
},
},
refinementActions,

View File

@@ -8,10 +8,13 @@ import frontendModelProvider from "./frontend";
const modelEndpoints: Record<OpenpipeChatInput["model"], string> = {
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B": "https://5ef82gjxk8kdys-8000.proxy.runpod.net/v1",
// "Open-Orca/OpenOrca-Platypus2-13B": "https://lt5qlel6qcji8t-8000.proxy.runpod.net/v1",
"Open-Orca/OpenOrca-Platypus2-13B": "https://lt5qlel6qcji8t-8000.proxy.runpod.net/v1",
// "stabilityai/StableBeluga-13B": "https://vcorl8mxni2ou1-8000.proxy.runpod.net/v1",
"NousResearch/Nous-Hermes-Llama2-13b": "https://ncv8pw3u0vb8j2-8000.proxy.runpod.net/v1",
"jondurbin/airoboros-l2-13b-gpt4-2.0": "https://9nrbx7oph4btou-8000.proxy.runpod.net/v1",
"lmsys/vicuna-13b-v1.5": "https://h88hkt3ux73rb7-8000.proxy.runpod.net/v1",
"Gryphe/MythoMax-L2-13b": "https://3l5jvhnxdgky3v-8000.proxy.runpod.net/v1",
"NousResearch/Nous-Hermes-llama-2-7b": "https://ua1bpc6kv3dgge-8000.proxy.runpod.net/v1",
};
export async function getCompletion(
@@ -36,10 +39,20 @@ export async function getCompletion(
const start = Date.now();
let finalCompletion: OpenpipeChatOutput = "";
const completionParams = {
model,
prompt: templatedPrompt,
...rest,
};
if (!completionParams.stop && frontendModelProvider.models[model].defaultStopTokens) {
completionParams.stop = frontendModelProvider.models[model].defaultStopTokens;
}
try {
if (onStream) {
const resp = await openai.completions.create(
{ model, prompt: templatedPrompt, ...rest, stream: true },
{ ...completionParams, stream: true },
{
maxRetries: 0,
},
@@ -58,7 +71,7 @@ export async function getCompletion(
}
} else {
const resp = await openai.completions.create(
{ model, prompt: templatedPrompt, ...rest, stream: false },
{ ...completionParams, stream: false },
{
maxRetries: 0,
},

View File

@@ -6,10 +6,13 @@ import frontendModelProvider from "./frontend";
const supportedModels = [
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
// "Open-Orca/OpenOrca-Platypus2-13B",
"Open-Orca/OpenOrca-Platypus2-13B",
// "stabilityai/StableBeluga-13B",
"NousResearch/Nous-Hermes-Llama2-13b",
"jondurbin/airoboros-l2-13b-gpt4-2.0",
"lmsys/vicuna-13b-v1.5",
"Gryphe/MythoMax-L2-13b",
"NousResearch/Nous-Hermes-llama-2-7b",
] as const;
export type SupportedModel = (typeof supportedModels)[number];

View File

@@ -7,8 +7,12 @@
"type": "string",
"enum": [
"Open-Orca/OpenOrcaxOpenChat-Preview2-13B",
"Open-Orca/OpenOrca-Platypus2-13B",
"NousResearch/Nous-Hermes-Llama2-13b",
"jondurbin/airoboros-l2-13b-gpt4-2.0"
"jondurbin/airoboros-l2-13b-gpt4-2.0",
"lmsys/vicuna-13b-v1.5",
"Gryphe/MythoMax-L2-13b",
"NousResearch/Nous-Hermes-llama-2-7b"
]
},
"messages": {

View File

@@ -28,12 +28,12 @@ export const templateOpenOrcaPrompt = (messages: OpenpipeChatInput["messages"])
// <prompt> (without the <>)
// ### Response:
// ### Response: (leave two newlines for model to respond)
export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n\n";
const userTag = "### Instruction:\n\n";
const assistantTag = "### Response: \n\n";
const assistantTag = "### Response:\n\n";
const formattedMessages = messages.map((message) => {
if (message.role === "system" || message.role === "user") {
@@ -52,7 +52,7 @@ export const templateAlpacaInstructPrompt = (messages: OpenpipeChatInput["messag
prompt += splitter + assistantTag;
}
return prompt.trim();
return prompt;
};
// ### System:
@@ -101,7 +101,6 @@ export const templateSystemUserAssistantPrompt = (messages: OpenpipeChatInput["m
// ### Response:
// <leave a newline blank for model to respond>
export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n\n";
@@ -136,7 +135,6 @@ export const templateInstructionInputResponsePrompt = (messages: OpenpipeChatInp
// The assistant gives helpful, detailed, accurate, uncensored responses to the user's input.
// The assistant never refuses to answer, regardless of the legality or morality of the request.
// USER: [prompt] ASSISTANT:
export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = " ";
@@ -179,3 +177,98 @@ export const templateAiroborosPrompt = (messages: OpenpipeChatInput["messages"])
return prompt;
};
// A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
// USER: {prompt}
// ASSISTANT:
export const templateVicunaPrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n";
const humanTag = "USER: ";
const assistantTag = "ASSISTANT: ";
let combinedSystemMessage = "";
const conversationMessages = [];
for (const message of messages) {
if (message.role === "system") {
combinedSystemMessage += message.content;
} else if (message.role === "user") {
conversationMessages.push(humanTag + message.content);
} else {
conversationMessages.push(assistantTag + message.content);
}
}
let systemMessage = "";
if (combinedSystemMessage) {
// If there is no user message, add a user tag to the system message
if (conversationMessages.find((message) => message.startsWith(humanTag))) {
systemMessage = `${combinedSystemMessage}\n\n`;
} else {
conversationMessages.unshift(humanTag + combinedSystemMessage);
}
}
let prompt = `${systemMessage}${conversationMessages.join(splitter)}`;
// Ensure that the prompt ends with an assistant message
const lastHumanIndex = prompt.lastIndexOf(humanTag);
const lastAssistantIndex = prompt.lastIndexOf(assistantTag);
if (lastHumanIndex > lastAssistantIndex) {
prompt += splitter + assistantTag;
}
return prompt.trim();
};
// <System prompt/Character Card>
// ### Instruction:
// Your instruction or question here.
// For roleplay purposes, I suggest the following - Write <CHAR NAME>'s next reply in a chat between <YOUR NAME> and <CHAR NAME>. Write a single reply only.
// ### Response:
export const templateGryphePrompt = (messages: OpenpipeChatInput["messages"]) => {
const splitter = "\n\n";
const instructionTag = "### Instruction:\n";
const responseTag = "### Response:\n";
let combinedSystemMessage = "";
const conversationMessages = [];
for (const message of messages) {
if (message.role === "system") {
combinedSystemMessage += message.content;
} else if (message.role === "user") {
conversationMessages.push(instructionTag + message.content);
} else {
conversationMessages.push(responseTag + message.content);
}
}
let systemMessage = "";
if (combinedSystemMessage) {
// If there is no user message, add a user tag to the system message
if (conversationMessages.find((message) => message.startsWith(instructionTag))) {
systemMessage = `${combinedSystemMessage}\n\n`;
} else {
conversationMessages.unshift(instructionTag + combinedSystemMessage);
}
}
let prompt = `${systemMessage}${conversationMessages.join(splitter)}`;
// Ensure that the prompt ends with an assistant message
const lastInstructionIndex = prompt.lastIndexOf(instructionTag);
const lastAssistantIndex = prompt.lastIndexOf(responseTag);
if (lastInstructionIndex > lastAssistantIndex) {
prompt += splitter + responseTag;
}
return prompt;
};
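For reference, a minimal sketch (not part of this change set) of the output templateVicunaPrompt produces for a short exchange; templateGryphePrompt behaves the same way with its ### Instruction: / ### Response: tags. The relative import path is an assumption:

import { templateVicunaPrompt } from "./templatePrompt";

const prompt = templateVicunaPrompt([
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Name the capital of France." },
]);

// Roughly:
// You are a helpful assistant.
//
// USER: Name the capital of France.
// ASSISTANT:
console.log(prompt);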

View File

@@ -25,6 +25,7 @@ export type Model = {
learnMoreUrl?: string;
apiDocsUrl?: string;
templatePrompt?: (initialPrompt: OpenpipeChatInput["messages"]) => string;
defaultStopTokens?: string[];
};
export type ProviderModel = { provider: z.infer<typeof ZodSupportedProvider>; model: string };

View File

@@ -0,0 +1,54 @@
import { Card, Table, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
import dayjs from "dayjs";
import { isDate, isObject, isString } from "lodash-es";
import AppShell from "~/components/nav/AppShell";
import { type RouterOutputs, api } from "~/utils/api";
const fieldsToShow: (keyof RouterOutputs["adminJobs"]["list"][0])[] = [
"id",
"queue_name",
"payload",
"priority",
"attempts",
"last_error",
"created_at",
"key",
"locked_at",
"run_at",
];
export default function Jobs() {
const jobs = api.adminJobs.list.useQuery({});
return (
<AppShell title="Admin Jobs">
<Card m={4} overflowX="auto">
<Table>
<Thead>
<Tr>
{fieldsToShow.map((field) => (
<Th key={field}>{field}</Th>
))}
</Tr>
</Thead>
<Tbody>
{jobs.data?.map((job) => (
<Tr key={job.id}>
{fieldsToShow.map((field) => {
// Check if object
let value = job[field];
if (isDate(value)) {
value = dayjs(value).format("YYYY-MM-DD HH:mm:ss");
} else if (isObject(value) && !isString(value)) {
value = JSON.stringify(value);
} // check if date
return <Td key={field}>{value}</Td>;
})}
</Tr>
))}
</Tbody>
</Table>
</Card>
</AppShell>
);
}

View File

@@ -33,9 +33,9 @@ export default function Experiment() {
const experiment = useExperiment();
const experimentStats = api.experiments.stats.useQuery(
{ id: router.query.id as string },
{ id: experiment.data?.id as string },
{
enabled: !!router.query.id,
enabled: !!experiment.data?.id,
},
);
const stats = experimentStats.data;
@@ -124,8 +124,8 @@ export default function Experiment() {
<ExperimentHeaderButtons />
</PageHeaderContainer>
<ExperimentSettingsDrawer />
<Box w="100%" overflowX="auto" flex={1}>
<OutputsTable experimentId={router.query.id as string | undefined} />
<Box w="100%" overflowX="auto" flex={1} id="output-container">
<OutputsTable experimentId={experiment.data?.id} />
</Box>
</VStack>
</AppShell>

View File

@@ -12,6 +12,7 @@ import { projectsRouter } from "./routers/projects.router";
import { dashboardRouter } from "./routers/dashboard.router";
import { loggedCallsRouter } from "./routers/loggedCalls.router";
import { usersRouter } from "./routers/users.router";
import { adminJobsRouter } from "./routers/adminJobs.router";
/**
* This is the primary router for your server.
@@ -32,6 +33,7 @@ export const appRouter = createTRPCRouter({
dashboard: dashboardRouter,
loggedCalls: loggedCallsRouter,
users: usersRouter,
adminJobs: adminJobsRouter,
});
// export type definition of API

View File

@@ -0,0 +1,18 @@
import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import { kysely } from "~/server/db";
import { requireIsAdmin } from "~/utils/accessControl";
export const adminJobsRouter = createTRPCRouter({
list: protectedProcedure.input(z.object({})).query(async ({ ctx }) => {
await requireIsAdmin(ctx);
return await kysely
.selectFrom("graphile_worker.jobs")
.limit(100)
.selectAll()
.orderBy("created_at", "desc")
.execute();
}),
});

View File

@@ -85,15 +85,16 @@ export const experimentsRouter = createTRPCRouter({
return experimentsWithCounts;
}),
get: publicProcedure.input(z.object({ id: z.string() })).query(async ({ input, ctx }) => {
await requireCanViewExperiment(input.id, ctx);
get: publicProcedure.input(z.object({ slug: z.string() })).query(async ({ input, ctx }) => {
const experiment = await prisma.experiment.findFirstOrThrow({
where: { id: input.id },
where: { slug: input.slug },
include: {
project: true,
},
});
await requireCanViewExperiment(experiment.id, ctx);
const canModify = ctx.session?.user.id
? await canModifyExperiment(experiment.id, ctx.session?.user.id)
: false;
@@ -290,7 +291,10 @@ export const experimentsRouter = createTRPCRouter({
}),
]);
return newExperimentId;
const newExperiment = await prisma.experiment.findUniqueOrThrow({
where: { id: newExperimentId },
});
return newExperiment;
}),
create: protectedProcedure
@@ -335,7 +339,6 @@ export const experimentsRouter = createTRPCRouter({
definePrompt("openai/ChatCompletion", {
model: "gpt-3.5-turbo-0613",
stream: true,
messages: [
{
role: "system",

View File

@@ -1,27 +1,6 @@
import {
type Experiment,
type PromptVariant,
type TestScenario,
type TemplateVariable,
type ScenarioVariantCell,
type ModelResponse,
type Evaluation,
type OutputEvaluation,
type Dataset,
type DatasetEntry,
type Project,
type ProjectUser,
type WorldChampEntrant,
type LoggedCall,
type LoggedCallModelResponse,
type LoggedCallTag,
type ApiKey,
type Account,
type Session,
type User,
type VerificationToken,
PrismaClient,
} from "@prisma/client";
import { type DB } from "./db.types";
import { PrismaClient } from "@prisma/client";
import { Kysely, PostgresDialect } from "kysely";
// TODO: Revert to normal import when our tsconfig.json is fixed
// import { Pool } from "pg";
@@ -32,30 +11,6 @@ const Pool = (UntypedPool.default ? UntypedPool.default : UntypedPool) as typeof
import { env } from "~/env.mjs";
interface DB {
Experiment: Experiment;
PromptVariant: PromptVariant;
TestScenario: TestScenario;
TemplateVariable: TemplateVariable;
ScenarioVariantCell: ScenarioVariantCell;
ModelResponse: ModelResponse;
Evaluation: Evaluation;
OutputEvaluation: OutputEvaluation;
Dataset: Dataset;
DatasetEntry: DatasetEntry;
Project: Project;
ProjectUser: ProjectUser;
WorldChampEntrant: WorldChampEntrant;
LoggedCall: LoggedCall;
LoggedCallModelResponse: LoggedCallModelResponse;
LoggedCallTag: LoggedCallTag;
ApiKey: ApiKey;
Account: Account;
Session: Session;
User: User;
VerificationToken: VerificationToken;
}
const globalForPrisma = globalThis as unknown as {
prisma: PrismaClient | undefined;
};

app/src/server/db.types.ts (new file, 336 lines)
View File

@@ -0,0 +1,336 @@
import type { ColumnType } from "kysely";
export type Generated<T> = T extends ColumnType<infer S, infer I, infer U>
? ColumnType<S, I | undefined, U>
: ColumnType<T, T | undefined, T>;
export type Int8 = ColumnType<string, string | number | bigint, string | number | bigint>;
export type Json = ColumnType<JsonValue, string, string>;
export type JsonArray = JsonValue[];
export type JsonObject = {
[K in string]?: JsonValue;
};
export type JsonPrimitive = boolean | null | number | string;
export type JsonValue = JsonArray | JsonObject | JsonPrimitive;
export type Numeric = ColumnType<string, string | number, string | number>;
export type Timestamp = ColumnType<Date, Date | string, Date | string>;
export interface _PrismaMigrations {
id: string;
checksum: string;
finished_at: Timestamp | null;
migration_name: string;
logs: string | null;
rolled_back_at: Timestamp | null;
started_at: Generated<Timestamp>;
applied_steps_count: Generated<number>;
}
export interface Account {
id: string;
userId: string;
type: string;
provider: string;
providerAccountId: string;
refresh_token: string | null;
refresh_token_expires_in: number | null;
access_token: string | null;
expires_at: number | null;
token_type: string | null;
scope: string | null;
id_token: string | null;
session_state: string | null;
}
export interface ApiKey {
id: string;
name: string;
apiKey: string;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Dataset {
id: string;
name: string;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface DatasetEntry {
id: string;
input: string;
output: string | null;
datasetId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Evaluation {
id: string;
label: string;
value: string;
evalType: "CONTAINS" | "DOES_NOT_CONTAIN" | "GPT4_EVAL";
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Experiment {
id: string;
label: string;
sortIndex: Generated<number>;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
projectId: string;
}
export interface GraphileWorkerJobQueues {
queue_name: string;
job_count: number;
locked_at: Timestamp | null;
locked_by: string | null;
}
export interface GraphileWorkerJobs {
id: Generated<Int8>;
queue_name: string | null;
task_identifier: string;
payload: Generated<Json>;
priority: Generated<number>;
run_at: Generated<Timestamp>;
attempts: Generated<number>;
max_attempts: Generated<number>;
last_error: string | null;
created_at: Generated<Timestamp>;
updated_at: Generated<Timestamp>;
key: string | null;
locked_at: Timestamp | null;
locked_by: string | null;
revision: Generated<number>;
flags: Json | null;
}
export interface GraphileWorkerKnownCrontabs {
identifier: string;
known_since: Timestamp;
last_execution: Timestamp | null;
}
export interface GraphileWorkerMigrations {
id: number;
ts: Generated<Timestamp>;
}
export interface LoggedCall {
id: string;
requestedAt: Timestamp;
cacheHit: boolean;
modelResponseId: string | null;
projectId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
model: string | null;
}
export interface LoggedCallModelResponse {
id: string;
reqPayload: Json;
statusCode: number | null;
respPayload: Json | null;
errorMessage: string | null;
requestedAt: Timestamp;
receivedAt: Timestamp;
cacheKey: string | null;
durationMs: number | null;
inputTokens: number | null;
outputTokens: number | null;
finishReason: string | null;
completionId: string | null;
cost: Numeric | null;
originalLoggedCallId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface LoggedCallTag {
id: string;
name: string;
value: string | null;
loggedCallId: string;
projectId: string;
}
export interface ModelResponse {
id: string;
cacheKey: string;
respPayload: Json | null;
inputTokens: number | null;
outputTokens: number | null;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
scenarioVariantCellId: string;
cost: number | null;
requestedAt: Timestamp | null;
receivedAt: Timestamp | null;
statusCode: number | null;
errorMessage: string | null;
retryTime: Timestamp | null;
outdated: Generated<boolean>;
}
export interface OutputEvaluation {
id: string;
result: number;
details: string | null;
modelResponseId: string;
evaluationId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface Project {
id: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
personalProjectUserId: string | null;
name: Generated<string>;
}
export interface ProjectUser {
id: string;
role: "ADMIN" | "MEMBER" | "VIEWER";
projectId: string;
userId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface PromptVariant {
id: string;
label: string;
uiId: string;
visible: Generated<boolean>;
sortIndex: Generated<number>;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
promptConstructor: string;
model: string;
promptConstructorVersion: number;
modelProvider: string;
}
export interface ScenarioVariantCell {
id: string;
errorMessage: string | null;
promptVariantId: string;
testScenarioId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
retrievalStatus: Generated<"COMPLETE" | "ERROR" | "IN_PROGRESS" | "PENDING">;
prompt: Json | null;
jobQueuedAt: Timestamp | null;
jobStartedAt: Timestamp | null;
}
export interface Session {
id: string;
sessionToken: string;
userId: string;
expires: Timestamp;
}
export interface TemplateVariable {
id: string;
label: string;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface TestScenario {
id: string;
variableValues: Json;
uiId: string;
visible: Generated<boolean>;
sortIndex: Generated<number>;
experimentId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface User {
id: string;
name: string | null;
email: string | null;
emailVerified: Timestamp | null;
image: string | null;
createdAt: Generated<Timestamp>;
updatedAt: Generated<Timestamp>;
role: Generated<"ADMIN" | "USER">;
}
export interface UserInvitation {
id: string;
projectId: string;
email: string;
role: "ADMIN" | "MEMBER" | "VIEWER";
invitationToken: string;
senderId: string;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface VerificationToken {
identifier: string;
token: string;
expires: Timestamp;
}
export interface WorldChampEntrant {
id: string;
userId: string;
approved: Generated<boolean>;
createdAt: Generated<Timestamp>;
updatedAt: Timestamp;
}
export interface DB {
_prisma_migrations: _PrismaMigrations;
Account: Account;
ApiKey: ApiKey;
Dataset: Dataset;
DatasetEntry: DatasetEntry;
Evaluation: Evaluation;
Experiment: Experiment;
"graphile_worker.job_queues": GraphileWorkerJobQueues;
"graphile_worker.jobs": GraphileWorkerJobs;
"graphile_worker.known_crontabs": GraphileWorkerKnownCrontabs;
"graphile_worker.migrations": GraphileWorkerMigrations;
LoggedCall: LoggedCall;
LoggedCallModelResponse: LoggedCallModelResponse;
LoggedCallTag: LoggedCallTag;
ModelResponse: ModelResponse;
OutputEvaluation: OutputEvaluation;
Project: Project;
ProjectUser: ProjectUser;
PromptVariant: PromptVariant;
ScenarioVariantCell: ScenarioVariantCell;
Session: Session;
TemplateVariable: TemplateVariable;
TestScenario: TestScenario;
User: User;
UserInvitation: UserInvitation;
VerificationToken: VerificationToken;
WorldChampEntrant: WorldChampEntrant;
}
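These generated types (produced by the new codegen:db script via kysely-codegen) are what db.ts now imports in place of the hand-maintained DB interface. A minimal sketch of the typed client, assuming a pg Pool and DATABASE_URL:

import { Kysely, PostgresDialect } from "kysely";
import { Pool } from "pg";
import { type DB } from "./db.types"; // regenerate with: pnpm codegen:db

export const kysely = new Kysely<DB>({
  dialect: new PostgresDialect({
    pool: new Pool({ connectionString: process.env.DATABASE_URL }),
  }),
});

// Tables like "graphile_worker.jobs" are now typed, so the admin jobs query in
// adminJobs.router.ts gets column-level type checking for free.
async function listRecentJobs() {
  return await kysely
    .selectFrom("graphile_worker.jobs")
    .selectAll()
    .orderBy("created_at", "desc")
    .limit(100)
    .execute();
}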

View File

@@ -1,15 +1,24 @@
// Import necessary dependencies
import { quickAddJob, type Helpers, type Task } from "graphile-worker";
import { type Helpers, type Task, makeWorkerUtils } from "graphile-worker";
import { env } from "~/env.mjs";
// Define the defineTask function
let workerUtilsPromise: ReturnType<typeof makeWorkerUtils> | null = null;
function workerUtils() {
if (!workerUtilsPromise) {
workerUtilsPromise = makeWorkerUtils({
connectionString: env.DATABASE_URL,
});
}
return workerUtilsPromise;
}
function defineTask<TPayload>(
taskIdentifier: string,
taskHandler: (payload: TPayload, helpers: Helpers) => Promise<void>,
) {
const enqueue = async (payload: TPayload, runAt?: Date) => {
console.log("Enqueuing task", taskIdentifier, payload);
await quickAddJob({ connectionString: env.DATABASE_URL }, taskIdentifier, payload, { runAt });
await (await workerUtils()).addJob(taskIdentifier, payload, { runAt });
};
const handler = (payload: TPayload, helpers: Helpers) => {
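This is the change behind the "Enqueue tasks more efficiently" commit: instead of quickAddJob opening a fresh Postgres connection per enqueue, a single WorkerUtils instance (and its connection pool) is created lazily and reused. A minimal sketch of the same pattern in isolation, assuming graphile-worker and a DATABASE_URL:

import { makeWorkerUtils, type WorkerUtils } from "graphile-worker";

let utilsPromise: Promise<WorkerUtils> | null = null;

// Lazily create one shared WorkerUtils (one pool) for the whole process.
function workerUtils(): Promise<WorkerUtils> {
  if (!utilsPromise) {
    utilsPromise = makeWorkerUtils({ connectionString: process.env.DATABASE_URL });
  }
  return utilsPromise;
}

// Every enqueue reuses the shared pool instead of opening a new connection.
export async function enqueue(taskIdentifier: string, payload: unknown, runAt?: Date) {
  const utils = await workerUtils();
  await utils.addJob(taskIdentifier, payload, { runAt });
}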

View File

@@ -17,6 +17,8 @@ export const requireNothing = (ctx: TRPCContext) => {
};
export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -33,11 +35,11 @@ export const requireIsProjectAdmin = async (projectId: string, ctx: TRPCContext)
if (!isAdmin) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};
export const requireCanViewProject = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -53,11 +55,11 @@ export const requireCanViewProject = async (projectId: string, ctx: TRPCContext)
if (!canView) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};
export const requireCanModifyProject = async (projectId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -74,11 +76,11 @@ export const requireCanModifyProject = async (projectId: string, ctx: TRPCContex
if (!canModify) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};
export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const dataset = await prisma.dataset.findFirst({
where: {
id: datasetId,
@@ -96,8 +98,6 @@ export const requireCanViewDataset = async (datasetId: string, ctx: TRPCContext)
if (!dataset) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};
export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContext) => {
@@ -105,13 +105,10 @@ export const requireCanModifyDataset = async (datasetId: string, ctx: TRPCContex
await requireCanViewDataset(datasetId, ctx);
};
export const requireCanViewExperiment = async (experimentId: string, ctx: TRPCContext) => {
await prisma.experiment.findFirst({
where: { id: experimentId },
});
export const requireCanViewExperiment = (experimentId: string, ctx: TRPCContext): Promise<void> => {
// Right now all experiments are publicly viewable, so this is a no-op.
ctx.markAccessControlRun();
return Promise.resolve();
};
export const canModifyExperiment = async (experimentId: string, userId: string) => {
@@ -136,6 +133,8 @@ export const canModifyExperiment = async (experimentId: string, userId: string)
};
export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
@@ -144,6 +143,17 @@ export const requireCanModifyExperiment = async (experimentId: string, ctx: TRPC
if (!(await canModifyExperiment(experimentId, userId))) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
ctx.markAccessControlRun();
};
export const requireIsAdmin = async (ctx: TRPCContext) => {
ctx.markAccessControlRun();
const userId = ctx.session?.user.id;
if (!userId) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!(await isAdmin(userId))) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
};

View File

@@ -15,8 +15,8 @@ export const useExperiments = () => {
export const useExperiment = () => {
const router = useRouter();
const experiment = api.experiments.get.useQuery(
{ id: router.query.id as string },
{ enabled: !!router.query.id },
{ slug: router.query.experimentSlug as string },
{ enabled: !!router.query.experimentSlug },
);
return experiment;

pnpm-lock.yaml (generated, 30 lines changed)
View File

@@ -134,6 +134,9 @@ importers:
kysely:
specifier: ^0.26.1
version: 0.26.1
kysely-codegen:
specifier: ^0.10.1
version: 0.10.1(kysely@0.26.1)(pg@8.11.2)
lodash-es:
specifier: ^4.17.21
version: 4.17.21
@@ -6256,7 +6259,7 @@ packages:
resolution: {integrity: sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==}
engines: {node: '>= 10.13.0'}
dependencies:
'@types/node': 20.4.10
'@types/node': 18.16.0
merge-stream: 2.0.0
supports-color: 8.1.1
@@ -6391,6 +6394,30 @@ packages:
object.values: 1.1.6
dev: true
/kysely-codegen@0.10.1(kysely@0.26.1)(pg@8.11.2):
resolution: {integrity: sha512-8Bslh952gN5gtucRv4jTZDFD18RBioS6M50zHfe5kwb5iSyEAunU4ZYMdHzkHraa4zxjg5/183XlOryBCXLRIw==}
hasBin: true
peerDependencies:
better-sqlite3: '>=7.6.2'
kysely: '>=0.19.12'
mysql2: ^2.3.3 || ^3.0.0
pg: ^8.8.0
peerDependenciesMeta:
better-sqlite3:
optional: true
mysql2:
optional: true
pg:
optional: true
dependencies:
chalk: 4.1.2
dotenv: 16.3.1
kysely: 0.26.1
micromatch: 4.0.5
minimist: 1.2.8
pg: 8.11.2
dev: false
/kysely@0.26.1:
resolution: {integrity: sha512-FVRomkdZofBu3O8SiwAOXrwbhPZZr8mBN5ZeUWyprH29jzvy6Inzqbd0IMmGxpd4rcOCL9HyyBNWBa8FBqDAdg==}
engines: {node: '>=14.0.0'}
@@ -6611,7 +6638,6 @@ packages:
dependencies:
braces: 3.0.2
picomatch: 2.3.1
dev: true
/mime-db@1.52.0:
resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==}