Move the external API into its own router

Auth logic isn't shared between the clients anyway, so co-locating them is confusing since you can't use the same clients to call both. This also makes the codegen clients less verbose.
This commit is contained in:
Kyle Corbitt
2023-08-14 16:56:50 -07:00
parent 8552baf632
commit c4cef35717
25 changed files with 291 additions and 235 deletions

View File

@@ -0,0 +1,95 @@
import type { ApiKey, Project } from "@prisma/client";
import { TRPCError, initTRPC } from "@trpc/server";
import { type CreateNextContextOptions } from "@trpc/server/adapters/next";
import superjson from "superjson";
import { type OpenApiMeta } from "trpc-openapi";
import { ZodError } from "zod";
import { prisma } from "~/server/db";
type CreateContextOptions = {
key:
| (ApiKey & {
project: Project;
})
| null;
};
/**
* This helper generates the "internals" for a tRPC context. If you need to use it, you can export
* it from here.
*
* Examples of things you may need it for:
* - testing, so we don't have to mock Next.js' req/res
* - tRPC's `createSSGHelpers`, where we don't have req/res
*
* @see https://create.t3.gg/en/usage/trpc#-serverapitrpcts
*/
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
key: opts.key,
};
};
export const createOpenApiContext = async (opts: CreateNextContextOptions) => {
const { req, res } = opts;
const apiKey = req.headers.authorization?.split(" ")[1] as string | null;
if (!apiKey) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const key = await prisma.apiKey.findUnique({
where: { apiKey },
include: { project: true },
});
if (!key) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return createInnerTRPCContext({
key,
});
};
export type TRPCContext = Awaited<ReturnType<typeof createOpenApiContext>>;
const t = initTRPC
.context<typeof createOpenApiContext>()
.meta<OpenApiMeta>()
.create({
transformer: superjson,
errorFormatter({ shape, error }) {
return {
...shape,
data: {
...shape.data,
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
},
};
},
});
export const createOpenApiRouter = t.router;
export const openApiPublicProc = t.procedure;
/** Reusable middleware that enforces users are logged in before running the procedure. */
const enforceApiKey = t.middleware(async ({ ctx, next }) => {
if (!ctx.key) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return next({
ctx: { key: ctx.key },
});
});
/**
* Protected (authenticated) procedure
*
* If you want a query or mutation to ONLY be accessible to logged in users, use this. It verifies
* the session is valid and guarantees `ctx.session.user` is not null.
*
* @see https://trpc.io/docs/procedures
*/
export const openApiProtectedProc = t.procedure.use(enforceApiKey);

View File

@@ -2,9 +2,6 @@ import { type Prisma } from "@prisma/client";
import { type JsonValue } from "type-fest";
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject";
import modelProvider from "~/modelProviders/openai-ChatCompletion";
@@ -12,6 +9,7 @@ import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat/completions";
import { createOpenApiRouter, openApiProtectedProc } from "./openApiTrpc";
const reqValidator = z.object({
model: z.string(),
@@ -28,12 +26,12 @@ const respValidator = z.object({
),
});
export const externalApiRouter = createTRPCRouter({
checkCache: publicProcedure
export const v1ApiRouter = createOpenApiRouter({
checkCache: openApiProtectedProc
.meta({
openapi: {
method: "POST",
path: "/v1/check-cache",
path: "/check-cache",
description: "Check if a prompt is cached",
protect: true,
},
@@ -56,18 +54,8 @@ export const externalApiRouter = createTRPCRouter({
}),
)
.mutation(async ({ input, ctx }) => {
const apiKey = ctx.apiKey;
if (!apiKey) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const key = await prisma.apiKey.findUnique({
where: { apiKey },
});
if (!key) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const reqPayload = await reqValidator.spa(input.reqPayload);
const cacheKey = hashRequest(key.projectId, reqPayload as JsonValue);
const cacheKey = hashRequest(ctx.key.projectId, reqPayload as JsonValue);
const existingResponse = await prisma.loggedCallModelResponse.findFirst({
where: { cacheKey },
@@ -79,7 +67,7 @@ export const externalApiRouter = createTRPCRouter({
await prisma.loggedCall.create({
data: {
projectId: key.projectId,
projectId: ctx.key.projectId,
requestedAt: new Date(input.requestedAt),
cacheHit: true,
modelResponseId: existingResponse.id,
@@ -91,11 +79,11 @@ export const externalApiRouter = createTRPCRouter({
};
}),
report: publicProcedure
report: openApiProtectedProc
.meta({
openapi: {
method: "POST",
path: "/v1/report",
path: "/report",
description: "Report an API call",
protect: true,
},
@@ -119,20 +107,10 @@ export const externalApiRouter = createTRPCRouter({
.output(z.void())
.mutation(async ({ input, ctx }) => {
console.log("GOT TAGS", input.tags);
const apiKey = ctx.apiKey;
if (!apiKey) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const key = await prisma.apiKey.findUnique({
where: { apiKey },
});
if (!key) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const reqPayload = await reqValidator.spa(input.reqPayload);
const respPayload = await respValidator.spa(input.respPayload);
const requestHash = hashRequest(key.projectId, reqPayload as JsonValue);
const requestHash = hashRequest(ctx.key.project.id, reqPayload as JsonValue);
const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4();
@@ -151,7 +129,7 @@ export const externalApiRouter = createTRPCRouter({
prisma.loggedCall.create({
data: {
id: newLoggedCallId,
projectId: key.projectId,
projectId: ctx.key.project.id,
requestedAt: new Date(input.requestedAt),
cacheHit: false,
model,

View File

@@ -8,7 +8,6 @@ import { evaluationsRouter } from "./routers/evaluations.router";
import { worldChampsRouter } from "./routers/worldChamps.router";
import { datasetsRouter } from "./routers/datasets.router";
import { datasetEntries } from "./routers/datasetEntries.router";
import { externalApiRouter } from "./routers/externalApi.router";
import { projectsRouter } from "./routers/projects.router";
import { dashboardRouter } from "./routers/dashboard.router";
import { loggedCallsRouter } from "./routers/loggedCalls.router";
@@ -31,7 +30,6 @@ export const appRouter = createTRPCRouter({
projects: projectsRouter,
dashboard: dashboardRouter,
loggedCalls: loggedCallsRouter,
externalApi: externalApiRouter,
});
// export type definition of API

View File

@@ -27,7 +27,6 @@ import { capturePath } from "~/utils/analytics/serverAnalytics";
type CreateContextOptions = {
session: Session | null;
apiKey: string | null;
};
// eslint-disable-next-line @typescript-eslint/no-empty-function
@@ -46,7 +45,6 @@ const noOp = () => {};
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
session: opts.session,
apiKey: opts.apiKey,
prisma,
markAccessControlRun: noOp,
};
@@ -64,11 +62,8 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
// Get the session from the server using the getServerSession wrapper function
const session = await getServerAuthSession({ req, res });
const apiKey = req.headers.authorization?.split(" ")[1] as string | null;
return createInnerTRPCContext({
session,
apiKey,
});
};

View File

@@ -1,5 +1,5 @@
import "dotenv/config";
import { openApiDocument } from "~/pages/api/openapi.json";
import { openApiDocument } from "~/pages/api/v1/openapi.json";
import fs from "fs";
import path from "path";
import { execSync } from "child_process";