diff --git a/.github/workflows/job_test_api_canary.yaml b/.github/workflows/job_test_api_canary.yaml index f49c5e9423..fbbd2178a9 100644 --- a/.github/workflows/job_test_api_canary.yaml +++ b/.github/workflows/job_test_api_canary.yaml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - shard: ["1/9", "2/9", "3/9", "4/9", "5/9", "6/9","7/9", "8/9","9/9"] + shard: ["1/9", "2/9", "3/9", "4/9", "5/9", "6/9", "7/9", "8/9", "9/9"] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/job_test_api_local.yaml b/.github/workflows/job_test_api_local.yaml index 4a0421a965..b73a8eab46 100644 --- a/.github/workflows/job_test_api_local.yaml +++ b/.github/workflows/job_test_api_local.yaml @@ -13,12 +13,9 @@ jobs: - name: Delete huge unnecessary tools folder run: rm -rf /opt/hostedtoolcache - - name: Run containers run: docker compose -f ./deployment/docker-compose.yaml up -d - - - name: Install uses: ./.github/actions/install with: @@ -28,7 +25,6 @@ jobs: - name: Build run: pnpm turbo run build --filter=./apps/api... - - name: Load Schema into MySQL run: pnpm drizzle-kit push working-directory: internal/db @@ -50,7 +46,7 @@ jobs: DATABASE_HOST: localhost:3900 DATABASE_USERNAME: unkey DATABASE_PASSWORD: password - CLICKHOUSE_URL: http://default:password@localhost:8123 + CLICKHOUSE_URL: http://default:password@localhost:8123 TEST_LOCAL: true - name: Dump logs diff --git a/.github/workflows/job_test_api_staging.yaml b/.github/workflows/job_test_api_staging.yaml index 14d330a048..31098beeb2 100644 --- a/.github/workflows/job_test_api_staging.yaml +++ b/.github/workflows/job_test_api_staging.yaml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - shard: ["1/9", "2/9", "3/9", "4/9", "5/9", "6/9","7/9", "8/9", "9/9"] + shard: ["1/9", "2/9", "3/9", "4/9", "5/9", "6/9", "7/9", "8/9", "9/9"] steps: - uses: actions/checkout@v4 @@ -34,8 +34,6 @@ jobs: - name: Wake ClickHouse run: curl -X GET ${{ secrets.CLICKHOUSE_URL }}/ping - - - name: Install uses: ./.github/actions/install with: diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 480fec38e2..ecee810096 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -46,5 +46,5 @@ jobs: # working-directory: apps/agent # test_go_api_local: - # name: Test GO API Local + # name: Test Go API Local # uses: ./.github/workflows/job_test_go_api_local.yaml diff --git a/apps/dashboard/app/(app)/logs/components/controls/components/logs-filters/components/paths-filter.tsx b/apps/dashboard/app/(app)/logs/components/controls/components/logs-filters/components/paths-filter.tsx index 36f40e4c86..f56da9b9ed 100644 --- a/apps/dashboard/app/(app)/logs/components/controls/components/logs-filters/components/paths-filter.tsx +++ b/apps/dashboard/app/(app)/logs/components/controls/components/logs-filters/components/paths-filter.tsx @@ -3,25 +3,21 @@ import { useFilters } from "@/app/(app)/logs/hooks/use-filters"; import { Checkbox } from "@/components/ui/checkbox"; import { trpc } from "@/lib/trpc/client"; import { Button } from "@unkey/ui"; -import { useCallback, useMemo } from "react"; +import { useCallback } from "react"; import { useCheckboxState } from "./hooks/use-checkbox-state"; export const PathsFilter = () => { - const dateNow = useMemo(() => Date.now(), []); - const { data: paths, isLoading } = trpc.logs.queryDistinctPaths.useQuery( - { currentDate: dateNow }, - { - select(paths) { - return paths - ? paths.map((path, index) => ({ - id: index + 1, - path, - checked: false, - })) - : []; - }, + const { data: paths, isLoading } = trpc.logs.queryDistinctPaths.useQuery(undefined, { + select(paths) { + return paths + ? paths.map((path, index) => ({ + id: index + 1, + path, + checked: false, + })) + : []; }, - ); + }); const { filters, updateFilters } = useFilters(); const { checkboxes, handleCheckboxChange, handleSelectAll, handleKeyDown } = useCheckboxState({ diff --git a/apps/dashboard/app/(app)/logs/components/controls/components/logs-search/index.tsx b/apps/dashboard/app/(app)/logs/components/controls/components/logs-search/index.tsx index c6805cfdf0..686b3885ab 100644 --- a/apps/dashboard/app/(app)/logs/components/controls/components/logs-search/index.tsx +++ b/apps/dashboard/app/(app)/logs/components/controls/components/logs-search/index.tsx @@ -50,10 +50,7 @@ export const LogsSearch = () => { const query = search.trim(); if (query) { try { - await queryLLMForStructuredOutput.mutateAsync({ - query, - timestamp: Date.now(), - }); + await queryLLMForStructuredOutput.mutateAsync(query); } catch (error) { console.error("Search failed:", error); } diff --git a/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx b/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx index 9adce0c266..35a45493fb 100644 --- a/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx +++ b/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx @@ -29,6 +29,11 @@ export default async function StripeSuccess(props: Props) { const ws = await db.query.workspaces.findFirst({ where: (table, { and, eq, isNull }) => and(eq(table.tenantId, tenantId), isNull(table.deletedAt)), + with: { + auditLogBuckets: { + where: (table, { eq }) => eq(table.name, "unkey_mutations"), + }, + }, }); if (!ws) { return redirect("/new"); @@ -99,7 +104,7 @@ export default async function StripeSuccess(props: Props) { .where(eq(schema.workspaces.id, ws.id)); if (isUpgradingPlan) { - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ws.auditLogBuckets[0].id, { workspaceId: ws.id, actor: { type: "user", id: user.id }, event: "workspace.update", diff --git a/apps/dashboard/app/new/create-ratelimit.tsx b/apps/dashboard/app/new/create-ratelimit.tsx index 927894b64f..c7908c065d 100644 --- a/apps/dashboard/app/new/create-ratelimit.tsx +++ b/apps/dashboard/app/new/create-ratelimit.tsx @@ -4,11 +4,15 @@ import { getTenantId } from "@/lib/auth"; import { router } from "@/lib/trpc/routers"; import { auth } from "@clerk/nextjs"; import { createCallerFactory } from "@trpc/server"; +import type { AuditLogBucket, Workspace } from "@unkey/db"; import { Button } from "@unkey/ui"; import { GlobeLock } from "lucide-react"; import Link from "next/link"; -export const CreateRatelimit: React.FC = async () => { +type Props = { + workspace: Workspace & { auditLogBucket: AuditLogBucket }; +}; +export const CreateRatelimit: React.FC = async (props) => { const { sessionClaims, userId } = auth(); if (!userId) { return null; @@ -20,6 +24,7 @@ export const CreateRatelimit: React.FC = async () => { user: { id: userId, }, + workspace: props.workspace, tenant: { id: tenantId, role: "", @@ -30,10 +35,6 @@ export const CreateRatelimit: React.FC = async () => { }, }); - await trpc.workspace.optIntoBeta({ - feature: "ratelimit", - }); - const rootKey = await trpc.rootKey.create({ name: "onboarding", permissions: ["ratelimit.*.create_namespace", "ratelimit.*.limit"], diff --git a/apps/dashboard/app/new/page.tsx b/apps/dashboard/app/new/page.tsx index 588f74cfc5..f831c09bd8 100644 --- a/apps/dashboard/app/new/page.tsx +++ b/apps/dashboard/app/new/page.tsx @@ -162,6 +162,11 @@ export default async function (props: Props) { const workspace = await db.query.workspaces.findFirst({ where: (table, { and, eq, isNull }) => and(eq(table.id, props.searchParams.workspaceId!), isNull(table.deletedAt)), + with: { + auditLogBuckets: { + where: (table, { eq }) => eq(table.name, "unkey_mutations"), + }, + }, }); if (!workspace) { return redirect("/new"); @@ -184,7 +189,9 @@ export default async function (props: Props) { - + ); } @@ -210,7 +217,17 @@ export default async function (props: Props) { subscriptions: null, createdAt: new Date(), }); - await insertAuditLogs(tx, { + + const bucketId = newId("auditLogBucket"); + await tx.insert(schema.auditLogBucket).values({ + id: bucketId, + workspaceId, + name: "unkey_mutations", + retentionDays: 30, + deleteProtection: true, + }); + + await insertAuditLogs(tx, bucketId, { workspaceId: workspaceId, event: "workspace.create", actor: { diff --git a/apps/dashboard/lib/audit.ts b/apps/dashboard/lib/audit.ts index b3d4c21eba..96c0b90663 100644 --- a/apps/dashboard/lib/audit.ts +++ b/apps/dashboard/lib/audit.ts @@ -91,14 +91,9 @@ export type UnkeyAuditLog = { }; }; -const BUCKET_NAME = "unkey_mutations"; - -type Key = `${string}::${string}`; -type BucketId = string; -const bucketCache = new Map(); - export async function insertAuditLogs( db: Transaction | Database, + bucketId: string, logOrLogs: MaybeArray, ) { const logs = Array.isArray(logOrLogs) ? logOrLogs : [logOrLogs]; @@ -108,35 +103,6 @@ export async function insertAuditLogs( } for (const log of logs) { - // 1. Get the bucketId or create one if necessary - const key: Key = `${log.workspaceId}::${BUCKET_NAME}`; - let bucketId = ""; - const cachedBucketId = bucketCache.get(key); - if (cachedBucketId) { - bucketId = cachedBucketId; - } else { - const bucket = await db.query.auditLogBucket.findFirst({ - where: (table, { eq, and }) => - and(eq(table.workspaceId, log.workspaceId), eq(table.name, BUCKET_NAME)), - columns: { - id: true, - }, - }); - if (bucket) { - bucketId = bucket.id; - } else { - bucketId = newId("auditLogBucket"); - await db.insert(schema.auditLogBucket).values({ - id: bucketId, - workspaceId: log.workspaceId, - name: BUCKET_NAME, - }); - } - } - bucketCache.set(key, bucketId); - - // 2. Insert the log - const auditLogId = newId("auditLog"); await db.insert(schema.auditLog).values({ id: auditLogId, diff --git a/apps/dashboard/lib/trpc/context.ts b/apps/dashboard/lib/trpc/context.ts index ace7487442..c6405c4149 100644 --- a/apps/dashboard/lib/trpc/context.ts +++ b/apps/dashboard/lib/trpc/context.ts @@ -2,10 +2,49 @@ import type { inferAsyncReturnType } from "@trpc/server"; import type { FetchCreateContextFnOptions } from "@trpc/server/adapters/fetch"; import { getAuth } from "@clerk/nextjs/server"; +import { newId } from "@unkey/id"; +import { type AuditLogBucket, type Workspace, db, schema } from "../db"; export async function createContext({ req }: FetchCreateContextFnOptions) { const { userId, orgId, orgRole } = getAuth(req as any); + let ws: (Workspace & { auditLogBucket: AuditLogBucket }) | undefined; + const tenantId = orgId ?? userId; + if (tenantId) { + await db.transaction(async (tx) => { + const res = await tx.query.workspaces.findFirst({ + where: (table, { eq, and, isNull }) => + and(eq(table.tenantId, tenantId), isNull(table.deletedAt)), + with: { + auditLogBuckets: { + where: (table, { eq }) => eq(table.name, "unkey_mutations"), + }, + }, + }); + if (res) { + let auditLogBucket = res.auditLogBuckets.at(0); + // @ts-expect-error it should be undefined + delete res.auditLogBuckets; // we don't need to pollute or context + if (!auditLogBucket) { + auditLogBucket = { + id: newId("auditLogBucket"), + name: "unkey_mutations", + createdAt: Date.now(), + deleteProtection: true, + workspaceId: res.id, + retentionDays: 30, + updatedAt: null, + }; + await tx.insert(schema.auditLogBucket).values(auditLogBucket); + } + ws = { + ...res, + auditLogBucket, + }; + } + }); + } + return { req, audit: { @@ -13,6 +52,7 @@ export async function createContext({ req }: FetchCreateContextFnOptions) { location: req.headers.get("x-forwarded-for") ?? process.env.VERCEL_REGION ?? "unknown", }, user: userId ? { id: userId } : null, + workspace: ws, tenant: orgId && orgRole ? { diff --git a/apps/dashboard/lib/trpc/routers/api/create.ts b/apps/dashboard/lib/trpc/routers/api/create.ts index bc5ad78add..2268c6bdbd 100644 --- a/apps/dashboard/lib/trpc/routers/api/create.ts +++ b/apps/dashboard/lib/trpc/routers/api/create.ts @@ -12,34 +12,16 @@ export const createApi = t.procedure z.object({ name: z .string() - .min(3, "workspace names must contain at least 3 characters") - .max(50, "workspace names must contain at most 50 characters"), + .min(3, "API names must contain at least 3 characters") + .max(50, "API names must contain at most 50 characters"), }), ) .mutation(async ({ input, ctx }) => { - const ws = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "We are unable to create an API. Please try again or contact support@unkey.dev", - }); - }); - if (!ws) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "The workspace does not exist.", - }); - } - const keyAuthId = newId("keyAuth"); try { await db.insert(schema.keyAuth).values({ id: keyAuthId, - workspaceId: ws.id, + workspaceId: ctx.workspace.id, createdAt: new Date(), }); } catch (_err) { @@ -58,7 +40,7 @@ export const createApi = t.procedure .values({ id: apiId, name: input.name, - workspaceId: ws.id, + workspaceId: ctx.workspace.id, keyAuthId, authType: "key", ipWhitelist: null, @@ -72,8 +54,8 @@ export const createApi = t.procedure }); }); - await insertAuditLogs(tx, { - workspaceId: ws.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/delete.ts b/apps/dashboard/lib/trpc/routers/api/delete.ts index 90aa9e329c..23641be8b2 100644 --- a/apps/dashboard/lib/trpc/routers/api/delete.ts +++ b/apps/dashboard/lib/trpc/routers/api/delete.ts @@ -15,10 +15,11 @@ export const deleteApi = t.procedure const api = await db.query.apis .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.apiId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.apiId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -27,7 +28,7 @@ export const deleteApi = t.procedure "We are unable to delete this API. Please try again or contact support@unkey.dev", }); }); - if (!api || api.workspace.tenantId !== ctx.tenant.id) { + if (!api) { throw new TRPCError({ code: "NOT_FOUND", message: "The API does not exist. Please try again or contact support@unkey.dev", @@ -46,7 +47,7 @@ export const deleteApi = t.procedure .update(schema.apis) .set({ deletedAt: new Date() }) .where(eq(schema.apis.id, input.apiId)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: api.workspaceId, actor: { type: "user", @@ -78,8 +79,9 @@ export const deleteApi = t.procedure .where(eq(schema.keys.keyAuthId, api.keyAuthId!)); await insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, keyIds.map(({ id }) => ({ - workspaceId: api.workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/setDefaultBytes.ts b/apps/dashboard/lib/trpc/routers/api/setDefaultBytes.ts index 82d0ffb3e1..9533d39cd5 100644 --- a/apps/dashboard/lib/trpc/routers/api/setDefaultBytes.ts +++ b/apps/dashboard/lib/trpc/routers/api/setDefaultBytes.ts @@ -21,10 +21,11 @@ export const setDefaultApiBytes = t.procedure const keyAuth = await db.query.keyAuth .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyAuthId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyAuthId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -33,7 +34,7 @@ export const setDefaultApiBytes = t.procedure "We were unable to update the key auth. Please try again or contact support@unkey.dev", }); }); - if (!keyAuth || keyAuth.workspace.tenantId !== ctx.tenant.id) { + if (!keyAuth) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -55,8 +56,8 @@ export const setDefaultApiBytes = t.procedure "We were unable to update the API default bytes. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId: keyAuth.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/setDefaultPrefix.ts b/apps/dashboard/lib/trpc/routers/api/setDefaultPrefix.ts index d5f6feeeae..1fd0dea302 100644 --- a/apps/dashboard/lib/trpc/routers/api/setDefaultPrefix.ts +++ b/apps/dashboard/lib/trpc/routers/api/setDefaultPrefix.ts @@ -17,10 +17,11 @@ export const setDefaultApiPrefix = t.procedure const keyAuth = await db.query.keyAuth .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyAuthId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyAuthId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -29,7 +30,7 @@ export const setDefaultApiPrefix = t.procedure "We were unable to update the key auth. Please try again or contact support@unkey.dev", }); }); - if (!keyAuth || keyAuth.workspace.tenantId !== ctx.tenant.id) { + if (!keyAuth) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -52,8 +53,8 @@ export const setDefaultApiPrefix = t.procedure "We were unable to update the API default prefix. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId: keyAuth.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/updateDeleteProtection.ts b/apps/dashboard/lib/trpc/routers/api/updateDeleteProtection.ts index aabb412529..cea0e0b2e9 100644 --- a/apps/dashboard/lib/trpc/routers/api/updateDeleteProtection.ts +++ b/apps/dashboard/lib/trpc/routers/api/updateDeleteProtection.ts @@ -18,10 +18,11 @@ export const updateAPIDeleteProtection = t.procedure const api = await db.query.apis .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.apiId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.apiId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -30,7 +31,7 @@ export const updateAPIDeleteProtection = t.procedure "We were unable to update the API. Please try again or contact support@unkey.dev", }); }); - if (!api || api.workspace.tenantId !== ctx.tenant.id) { + if (!api) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -53,8 +54,8 @@ export const updateAPIDeleteProtection = t.procedure "We were unable to update the API. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId: api.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/updateIpWhitelist.ts b/apps/dashboard/lib/trpc/routers/api/updateIpWhitelist.ts index f86ec639ca..53aa4997b2 100644 --- a/apps/dashboard/lib/trpc/routers/api/updateIpWhitelist.ts +++ b/apps/dashboard/lib/trpc/routers/api/updateIpWhitelist.ts @@ -26,17 +26,17 @@ export const updateApiIpWhitelist = t.procedure }) .nullable(), apiId: z.string(), - workspaceId: z.string(), }), ) .mutation(async ({ input, ctx }) => { const api = await db.query.apis .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.apiId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.apiId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -46,11 +46,7 @@ export const updateApiIpWhitelist = t.procedure }); }); - if ( - !api || - api.workspace.tenantId !== ctx.tenant.id || - input.workspaceId !== api.workspace.id - ) { + if (!api) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -58,7 +54,7 @@ export const updateApiIpWhitelist = t.procedure }); } - if (!api.workspace.features.ipWhitelist) { + if (!ctx.workspace.features.ipWhitelist) { throw new TRPCError({ code: "FORBIDDEN", message: @@ -84,8 +80,8 @@ export const updateApiIpWhitelist = t.procedure }); }); - await insertAuditLogs(tx, { - workspaceId: api.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/api/updateName.ts b/apps/dashboard/lib/trpc/routers/api/updateName.ts index ee3cf687d1..90ec739793 100644 --- a/apps/dashboard/lib/trpc/routers/api/updateName.ts +++ b/apps/dashboard/lib/trpc/routers/api/updateName.ts @@ -19,10 +19,11 @@ export const updateApiName = t.procedure const api = await db.query.apis .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.apiId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.apiId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -31,7 +32,7 @@ export const updateApiName = t.procedure "We were unable to update the API name. Please try again or contact support@unkey.dev.", }); }); - if (!api || api.workspace.tenantId !== ctx.tenant.id) { + if (!api) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -53,8 +54,8 @@ export const updateApiName = t.procedure "We were unable to update the API name. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId: api.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/audit/fetch.ts b/apps/dashboard/lib/trpc/routers/audit/fetch.ts index 765bb0093f..5703d86d17 100644 --- a/apps/dashboard/lib/trpc/routers/audit/fetch.ts +++ b/apps/dashboard/lib/trpc/routers/audit/fetch.ts @@ -26,26 +26,6 @@ export const fetchAuditLog = rateLimitedProcedure(ratelimit.update) .query(async ({ ctx, input }) => { const { bucketName, events, users, rootKeys, cursor, limit, endTime, startTime } = input; - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "Failed to retrieve workspace logs due to an error. If this issue persists, please contact support@unkey.dev with the time this occurred.", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, please contact support using support@unkey.dev.", - }); - } - const selectedActorIds = [...rootKeys, ...users]; const result = await queryAuditLogs( @@ -58,7 +38,7 @@ export const fetchAuditLog = rateLimitedProcedure(ratelimit.update) events, limit, }, - workspace, + ctx.workspace, ); if (!result) { diff --git a/apps/dashboard/lib/trpc/routers/index.ts b/apps/dashboard/lib/trpc/routers/index.ts index 1215b341f9..2aeff0f48a 100644 --- a/apps/dashboard/lib/trpc/routers/index.ts +++ b/apps/dashboard/lib/trpc/routers/index.ts @@ -19,7 +19,6 @@ import { updateKeyOwnerId } from "./key/updateOwnerId"; import { updateKeyRatelimit } from "./key/updateRatelimit"; import { updateKeyRemaining } from "./key/updateRemaining"; import { updateRootKeyName } from "./key/updateRootKeyName"; -import { deleteLlmGateway } from "./llmGateway/delete"; import { llmSearch } from "./logs/llm-search"; import { queryDistinctPaths } from "./logs/query-distinct-paths"; import { queryLogs } from "./logs/query-logs"; @@ -67,9 +66,6 @@ export const router = t.router({ remaining: updateKeyRemaining, }), }), - llmGateway: t.router({ - delete: deleteLlmGateway, - }), rootKey: t.router({ create: createRootKey, delete: deleteRootKeys, diff --git a/apps/dashboard/lib/trpc/routers/key/create.ts b/apps/dashboard/lib/trpc/routers/key/create.ts index 47b51729e8..273fb592cc 100644 --- a/apps/dashboard/lib/trpc/routers/key/create.ts +++ b/apps/dashboard/lib/trpc/routers/key/create.ts @@ -36,29 +36,10 @@ export const createKey = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We were unable to create a key for this API. Please try again or contact support@unkey.dev.", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - const keyAuth = await db.query.keyAuth .findFirst({ - where: (table, { eq }) => eq(table.id, input.keyAuthId), + where: (table, { and, eq }) => + and(eq(table.workspaceId, ctx.workspace.id), eq(table.id, input.keyAuthId)), with: { api: true, }, @@ -70,7 +51,7 @@ export const createKey = t.procedure "We were unable to create a key for this API. Please try again or contact support@unkey.dev.", }); }); - if (!keyAuth || keyAuth.workspaceId !== workspace.id) { + if (!keyAuth) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -93,7 +74,7 @@ export const createKey = t.procedure start, ownerId: input.ownerId, meta: JSON.stringify(input.meta ?? {}), - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, forWorkspaceId: null, expires: input.expires ? new Date(input.expires) : null, createdAt: new Date(), @@ -109,8 +90,8 @@ export const createKey = t.procedure environment: input.environment, }); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "key.create", description: `Created ${keyId}`, diff --git a/apps/dashboard/lib/trpc/routers/key/createRootKey.ts b/apps/dashboard/lib/trpc/routers/key/createRootKey.ts index 35562275eb..0267708b1c 100644 --- a/apps/dashboard/lib/trpc/routers/key/createRootKey.ts +++ b/apps/dashboard/lib/trpc/routers/key/createRootKey.ts @@ -22,32 +22,13 @@ export const createRootKey = t.procedure }), ) .mutation(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We were unable to create a root key for this workspace. Please try again or contact support@unkey.dev.", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - const unkeyApi = await db.query.apis .findFirst({ - where: eq(schema.apis.id, env().UNKEY_API_ID), - with: { - workspace: true, - }, + where: (table, { and, eq }) => + and( + eq(table.workspaceId, env().UNKEY_WORKSPACE_ID), + eq(schema.apis.id, env().UNKEY_API_ID), + ), }) .catch((_err) => { throw new TRPCError({ @@ -87,7 +68,7 @@ export const createRootKey = t.procedure start, ownerId: ctx.user.id, workspaceId: env().UNKEY_WORKSPACE_ID, - forWorkspaceId: workspace.id, + forWorkspaceId: ctx.workspace.id, expires: null, createdAt: new Date(), remaining: null, @@ -99,7 +80,7 @@ export const createRootKey = t.procedure }); auditLogs.push({ - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "key.create", description: `Created ${keyId}`, @@ -129,11 +110,11 @@ export const createRootKey = t.procedure identityId = newId("identity"); await tx.insert(schema.identities).values({ id: identityId, - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, externalId: ctx.user.id, }); auditLogs.push({ - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "identity.create", description: `Created ${identityId}`, @@ -160,7 +141,7 @@ export const createRootKey = t.procedure auditLogs.push( ...permissions.map((p) => ({ - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user" as const, id: ctx.user.id }, event: "authorization.connect_permission_and_key" as const, description: `Connected ${p.id} and ${keyId}`, @@ -188,7 +169,7 @@ export const createRootKey = t.procedure workspaceId: env().UNKEY_WORKSPACE_ID, })), ); - await insertAuditLogs(tx, auditLogs); + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, auditLogs); }); } catch (_err) { throw new TRPCError({ diff --git a/apps/dashboard/lib/trpc/routers/key/delete.ts b/apps/dashboard/lib/trpc/routers/key/delete.ts index 33e38ac132..b089610c34 100644 --- a/apps/dashboard/lib/trpc/routers/key/delete.ts +++ b/apps/dashboard/lib/trpc/routers/key/delete.ts @@ -57,6 +57,7 @@ export const deleteKeys = t.procedure ); insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, workspace.keys.map((key) => ({ workspaceId: workspace.id, actor: { type: "user", id: ctx.user.id }, diff --git a/apps/dashboard/lib/trpc/routers/key/deleteRootKey.ts b/apps/dashboard/lib/trpc/routers/key/deleteRootKey.ts index 46703a92cd..48803e24b5 100644 --- a/apps/dashboard/lib/trpc/routers/key/deleteRootKey.ts +++ b/apps/dashboard/lib/trpc/routers/key/deleteRootKey.ts @@ -12,32 +12,11 @@ export const deleteRootKeys = t.procedure }), ) .mutation(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We were unable to delete this root key. Please try again or contact support@unkey.dev.", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - const rootKeys = await db.query.keys.findMany({ where: (table, { eq, inArray, isNull, and }) => and( eq(table.workspaceId, env().UNKEY_WORKSPACE_ID), - eq(table.forWorkspaceId, workspace.id), + eq(table.forWorkspaceId, ctx.workspace.id), inArray(table.id, input.keyIds), isNull(table.deletedAt), ), @@ -58,8 +37,9 @@ export const deleteRootKeys = t.procedure ); await insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, rootKeys.map((key) => ({ - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "key.delete", description: `Deleted ${key.id}`, diff --git a/apps/dashboard/lib/trpc/routers/key/updateEnabled.ts b/apps/dashboard/lib/trpc/routers/key/updateEnabled.ts index be05e3ed8a..4035761835 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateEnabled.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateEnabled.ts @@ -50,7 +50,7 @@ export const updateKeyEnabled = t.procedure "We were unable to update enabled on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: key.workspace.id, actor: { type: "user", diff --git a/apps/dashboard/lib/trpc/routers/key/updateExpiration.ts b/apps/dashboard/lib/trpc/routers/key/updateExpiration.ts index fff50f21b6..e5eee583a8 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateExpiration.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateExpiration.ts @@ -36,10 +36,11 @@ export const updateKeyExpiration = t.procedure const key = await db.query.keys .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -48,7 +49,7 @@ export const updateKeyExpiration = t.procedure "We were unable to update expiration on this key. Please try again or contact support@unkey.dev", }); }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the the correct key. Please try again or contact support@unkey.dev.", @@ -70,8 +71,8 @@ export const updateKeyExpiration = t.procedure "We were unable to update expiration on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateMetadata.ts b/apps/dashboard/lib/trpc/routers/key/updateMetadata.ts index 6694b56b9c..96e6bb6285 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateMetadata.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateMetadata.ts @@ -29,10 +29,11 @@ export const updateKeyMetadata = t.procedure const key = await db.query.keys .findFirst({ where: (table, { eq, isNull, and }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -41,7 +42,7 @@ export const updateKeyMetadata = t.procedure "We were unable to update metadata on this key. Please try again or contact support@unkey.dev", }); }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -63,8 +64,8 @@ export const updateKeyMetadata = t.procedure "We are unable to update metadata on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateName.ts b/apps/dashboard/lib/trpc/routers/key/updateName.ts index 539b7384f6..b139939320 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateName.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateName.ts @@ -15,10 +15,11 @@ export const updateKeyName = t.procedure const key = await db.query.keys .findFirst({ where: (table, { eq, isNull, and }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -27,7 +28,7 @@ export const updateKeyName = t.procedure "We were unable to update the name on this key. Please try again or contact support@unkey.dev", }); }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -49,8 +50,8 @@ export const updateKeyName = t.procedure "We are unable to update name on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateOwnerId.ts b/apps/dashboard/lib/trpc/routers/key/updateOwnerId.ts index 60f7e8ac3c..d376e20fff 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateOwnerId.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateOwnerId.ts @@ -15,10 +15,11 @@ export const updateKeyOwnerId = t.procedure const key = await db.query.keys .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -27,7 +28,7 @@ export const updateKeyOwnerId = t.procedure "We were unable to update ownerId on this key. Please try again or contact support@unkey.dev", }); }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -50,8 +51,8 @@ export const updateKeyOwnerId = t.procedure "We were unable to update ownerId on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateRatelimit.ts b/apps/dashboard/lib/trpc/routers/key/updateRatelimit.ts index 8316b8d503..961544daf3 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateRatelimit.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateRatelimit.ts @@ -19,10 +19,11 @@ export const updateKeyRatelimit = t.procedure const key = await db.query.keys .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -31,7 +32,7 @@ export const updateKeyRatelimit = t.procedure "We were unable to update ratelimits on this key. Please try again or contact support@unkey.dev", }); }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -60,8 +61,8 @@ export const updateKeyRatelimit = t.procedure ratelimitDuration, }) .where(eq(schema.keys.id, key.id)); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, @@ -97,8 +98,8 @@ export const updateKeyRatelimit = t.procedure }) .where(eq(schema.keys.id, key.id)); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateRemaining.ts b/apps/dashboard/lib/trpc/routers/key/updateRemaining.ts index e95fa7f7d6..d6b930c171 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateRemaining.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateRemaining.ts @@ -28,12 +28,13 @@ export const updateKeyRemaining = t.procedure .transaction(async (tx) => { const key = await tx.query.keys.findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), - with: { - workspace: true, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }); - if (!key || key.workspace.tenantId !== ctx.tenant.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -56,8 +57,8 @@ export const updateKeyRemaining = t.procedure "We were unable to update remaining on this key. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: key.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/key/updateRootKeyName.ts b/apps/dashboard/lib/trpc/routers/key/updateRootKeyName.ts index c67ab39789..1d5372ed05 100644 --- a/apps/dashboard/lib/trpc/routers/key/updateRootKeyName.ts +++ b/apps/dashboard/lib/trpc/routers/key/updateRootKeyName.ts @@ -14,31 +14,14 @@ export const updateRootKeyName = t.procedure .mutation(async ({ input, ctx }) => { const key = await db.query.keys.findFirst({ where: (table, { eq, isNull, and }) => - and(eq(table.id, input.keyId), isNull(table.deletedAt)), + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.keyId), + isNull(table.deletedAt), + ), }); - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We were unable to update root key name. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - - if (!key || key.forWorkspaceId !== workspace.id) { + if (!key) { throw new TRPCError({ message: "We are unable to find the correct key. Please try again or contact support@unkey.dev.", @@ -61,8 +44,8 @@ export const updateRootKeyName = t.procedure "We are unable to update root key name. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/llmGateway/create.ts b/apps/dashboard/lib/trpc/routers/llmGateway/create.ts deleted file mode 100644 index f31500999b..0000000000 --- a/apps/dashboard/lib/trpc/routers/llmGateway/create.ts +++ /dev/null @@ -1,83 +0,0 @@ -import { insertAuditLogs } from "@/lib/audit"; -import { db, schema } from "@/lib/db"; -import { DatabaseError } from "@planetscale/database"; -import { TRPCError } from "@trpc/server"; -import { newId } from "@unkey/id"; -import { z } from "zod"; -import { auth, t } from "../../trpc"; -export const createLlmGateway = t.procedure - .use(auth) - .input( - z.object({ - subdomain: z.string().min(1).max(50), - }), - ) - .mutation(async ({ input, ctx }) => { - const ws = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We were unable to create LLM gateway. Please try again or contact support@unkey.dev", - }); - }); - if (!ws) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - - const llmGatewayId = newId("llmGateway"); - - await db - .transaction(async (tx) => { - await tx.insert(schema.llmGateways).values({ - id: llmGatewayId, - subdomain: input.subdomain, - name: input.subdomain, - workspaceId: ws.id, - }); - await insertAuditLogs(tx, { - workspaceId: ws.id, - actor: { - type: "user", - id: ctx.user.id, - }, - event: "llmGateway.create", - description: `Created ${llmGatewayId}`, - resources: [ - { - type: "gateway", - id: llmGatewayId, - }, - ], - context: { - location: ctx.audit.location, - userAgent: ctx.audit.userAgent, - }, - }); - }) - .catch((err) => { - if (err instanceof DatabaseError && err.body.message.includes("Duplicate entry")) { - throw new TRPCError({ - code: "PRECONDITION_FAILED", - message: - "Gateway subdomains must have unique names. Please try a different subdomain.
If you believe this is an error and the subdomain should not be in use already, please contact support at support@unkey.dev", - }); - } - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Unable to create gateway, please contact support at support@unkey.dev", - }); - }); - - return { - id: llmGatewayId, - }; - }); diff --git a/apps/dashboard/lib/trpc/routers/llmGateway/delete.ts b/apps/dashboard/lib/trpc/routers/llmGateway/delete.ts deleted file mode 100644 index cfe2220cd9..0000000000 --- a/apps/dashboard/lib/trpc/routers/llmGateway/delete.ts +++ /dev/null @@ -1,81 +0,0 @@ -import { TRPCError } from "@trpc/server"; -import { z } from "zod"; - -import { insertAuditLogs } from "@/lib/audit"; -import { db, eq, schema } from "@/lib/db"; -import { auth, t } from "../../trpc"; -export const deleteLlmGateway = t.procedure - .use(auth) - .input(z.object({ gatewayId: z.string() })) - .mutation(async ({ ctx, input }) => { - const llmGateway = await db.query.llmGateways - .findFirst({ - where: (table, { eq, and }) => and(eq(table.id, input.gatewayId)), - with: { - workspace: { - columns: { - id: true, - tenantId: true, - }, - }, - }, - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to delete LLM gateway. Please try again or contact support@unkey.dev", - }); - }); - - if (!llmGateway || llmGateway.workspace.tenantId !== ctx.tenant.id) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "LLM gateway not found. Please try again or contact support@unkey.dev.", - }); - } - - await db - .transaction(async (tx) => { - await tx - .delete(schema.llmGateways) - .where(eq(schema.llmGateways.id, input.gatewayId)) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to delete the LLM gateway. Please try again or contact support@unkey.dev", - }); - }); - await insertAuditLogs(tx, { - workspaceId: llmGateway.workspace.id, - actor: { - type: "user", - id: ctx.user.id, - }, - event: "llmGateway.delete", - description: `Deleted ${llmGateway.id}`, - resources: [ - { - type: "gateway", - id: llmGateway.id, - }, - ], - context: { - location: ctx.audit.location, - userAgent: ctx.audit.userAgent, - }, - }); - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to delete LLM gateway. Please try again or contact support@unkey.dev", - }); - }); - - return { - id: llmGateway.id, - }; - }); diff --git a/apps/dashboard/lib/trpc/routers/logs/llm-search.ts b/apps/dashboard/lib/trpc/routers/logs/llm-search.ts new file mode 100644 index 0000000000..f6e111e6cb --- /dev/null +++ b/apps/dashboard/lib/trpc/routers/logs/llm-search.ts @@ -0,0 +1,159 @@ +import { METHODS } from "@/app/(app)/logs/constants"; +import { filterFieldConfig, filterOutputSchema } from "@/app/(app)/logs/filters.schema"; +import { env } from "@/lib/env"; +import { rateLimitedProcedure, ratelimit } from "@/lib/trpc/ratelimitProcedure"; +import { TRPCError } from "@trpc/server"; +import OpenAI from "openai"; +import { zodResponseFormat } from "openai/helpers/zod"; +import { z } from "zod"; + +const openai = env().OPENAI_API_KEY + ? new OpenAI({ + apiKey: env().OPENAI_API_KEY, + }) + : null; + +async function getStructuredSearchFromLLM(userSearchMsg: string) { + try { + if (!openai) { + return null; // Skip LLM processing in development environment when OpenAI API key is not configured + } + const completion = await openai.beta.chat.completions.parse({ + // Don't change the model only a few models allow structured outputs + model: "gpt-4o-2024-08-06", + temperature: 0.2, // Range 0-2, lower = more focused/deterministic + top_p: 0.1, // Alternative to temperature, controls randomness + frequency_penalty: 0.5, // Range -2 to 2, higher = less repetition + presence_penalty: 0.5, // Range -2 to 2, higher = more topic diversity + n: 1, // Number of completions to generate + messages: [ + { + role: "system", + content: getSystemPrompt(), + }, + { + role: "user", + content: userSearchMsg, + }, + ], + response_format: zodResponseFormat(filterOutputSchema, "searchQuery"), + }); + + if (!completion.choices[0].message.parsed) { + throw new TRPCError({ + code: "UNPROCESSABLE_CONTENT", + message: + "Try using phrases like:\n" + + "• 'find all POST requests'\n" + + "• 'show requests with status 404'\n" + + "• 'find requests to api/v1'\n" + + "• 'show requests from test.example.com'\n" + + "• 'find all GET and POST requests'\n" + + "For additional help, contact support@unkey.dev", + }); + } + + return completion.choices[0].message.parsed; + } catch (error) { + console.error( + `Something went wrong when querying OpenAI. Input: ${JSON.stringify( + userSearchMsg, + )}\n Output ${(error as Error).message}}`, + ); + if (error instanceof TRPCError) { + throw error; + } + + if ((error as any).response?.status === 429) { + throw new TRPCError({ + code: "TOO_MANY_REQUESTS", + message: "Search rate limit exceeded. Please try again in a few minutes.", + }); + } + + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: + "Failed to process your search query. Please try again or contact support@unkey.dev if the issue persists.", + }); + } +} + +export const llmSearch = rateLimitedProcedure(ratelimit.update) + .input(z.string()) + .mutation(async ({ input }) => { + return await getStructuredSearchFromLLM(input); + }); + +// HELPERS + +const getSystemPrompt = () => { + const operatorsByField = Object.entries(filterFieldConfig) + .map(([field, config]) => { + const operators = config.operators.join(", "); + let constraints = ""; + + if (field === "methods") { + constraints = ` and must be one of: ${METHODS.join(", ")}`; + } else if (field === "status") { + constraints = " and must be between 100-599"; + } + + return `- ${field} accepts ${operators} operator${ + config.operators.length > 1 ? "s" : "" + }${constraints}`; + }) + .join("\n"); + + return `You are an expert at converting natural language queries into filters. For queries with multiple conditions, output all relevant filters. We will process them in sequence to build the complete filter. Examples: + +Query: "path should start with /api/oz and method should be POST" +Result: [ + { + field: "paths", + filters: [{ operator: "startsWith", value: "/api/oz" }] + }, + { + field: "methods", + filters: [{ operator: "is", value: "POST" }] + } +] + +Query: "find POST and GET requests to api/v1" +Result: [ + { + field: "paths", + filters: [{ operator: "startsWith", value: "api/v1" }] + }, + { + field: "methods", + filters: [ + { operator: "is", value: "POST" }, + { operator: "is", value: "GET" } + ] + } +] + +Query: "show 404 requests from test.example.com" +Result: [ + { + field: "host", + filters: [{ operator: "is", value: "test.example.com" }] + }, + { + field: "status", + filters: [{ operator: "is", value: 404 }] + } +] + +Query: "find all POST requests" +Result: [ + { + field: "methods", + filters: [{ operator: "is", value: "POST" }] + } +] + +Remember: +${operatorsByField}`; +}; diff --git a/apps/dashboard/lib/trpc/routers/logs/llm-search/index.ts b/apps/dashboard/lib/trpc/routers/logs/llm-search/index.ts index 28ccde169f..0bf796c810 100644 --- a/apps/dashboard/lib/trpc/routers/logs/llm-search/index.ts +++ b/apps/dashboard/lib/trpc/routers/logs/llm-search/index.ts @@ -1,7 +1,5 @@ -import { db } from "@/lib/db"; import { env } from "@/lib/env"; import { rateLimitedProcedure, ratelimit } from "@/lib/trpc/ratelimitProcedure"; -import { TRPCError } from "@trpc/server"; import OpenAI from "openai"; import { z } from "zod"; import { getStructuredSearchFromLLM } from "./utils"; @@ -14,26 +12,6 @@ const openai = env().OPENAI_API_KEY export const llmSearch = rateLimitedProcedure(ratelimit.update) .input(z.object({ query: z.string(), timestamp: z.number() })) - .mutation(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "Failed to verify workspace access. Please try again or contact support@unkey.dev if this persists.", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, please contact support using support@unkey.dev.", - }); - } - + .mutation(async ({ input }) => { return await getStructuredSearchFromLLM(openai, input.query, input.timestamp); }); diff --git a/apps/dashboard/lib/trpc/routers/logs/query-distinct-paths.ts b/apps/dashboard/lib/trpc/routers/logs/query-distinct-paths.ts index e473e46df7..3d13e5e293 100644 --- a/apps/dashboard/lib/trpc/routers/logs/query-distinct-paths.ts +++ b/apps/dashboard/lib/trpc/routers/logs/query-distinct-paths.ts @@ -1,48 +1,23 @@ import { clickhouse } from "@/lib/clickhouse"; -import { db } from "@/lib/db"; import { rateLimitedProcedure, ratelimit } from "@/lib/trpc/ratelimitProcedure"; -import { TRPCError } from "@trpc/server"; import { z } from "zod"; -export const queryDistinctPaths = rateLimitedProcedure(ratelimit.update) - .input(z.object({ currentDate: z.number() })) - .query(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "Failed to retrieve distinct paths due to an error. If this issue persists, please contact support@unkey.dev with the time this occurred.", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, please contact support using support@unkey.dev.", - }); - } - - const fromDate = input.currentDate - 12 * 60 * 60 * 1000; // 12 hours ago - const result = await clickhouse.querier.query({ - query: ` - SELECT DISTINCT path - FROM metrics.raw_api_requests_v1 +export const queryDistinctPaths = rateLimitedProcedure(ratelimit.update).query(async ({ ctx }) => { + const result = await clickhouse.querier.query({ + query: ` + SELECT DISTINCT path + FROM metrics.raw_api_requests_v1 WHERE workspace_id = {workspaceId: String} AND time >= {fromDate: UInt64}`, - schema: z.object({ path: z.string() }), - params: z.object({ - workspaceId: z.string(), - fromDate: z.number(), - }), - })({ - workspaceId: workspace.id, - fromDate, - }); - - return result.val?.map((i) => i.path) ?? []; + schema: z.object({ path: z.string() }), + params: z.object({ + workspaceId: z.string(), + fromDate: z.number(), + }), + })({ + workspaceId: ctx.workspace.id, + fromDate: Date.now() - 12 * 60 * 60 * 1000, // 12 hours ag, }); + + return result.val?.map((i) => i.path) ?? []; +}); diff --git a/apps/dashboard/lib/trpc/routers/logs/query-timeseries/index.ts b/apps/dashboard/lib/trpc/routers/logs/query-timeseries/index.ts index 1cd6a4de19..0d6839e749 100644 --- a/apps/dashboard/lib/trpc/routers/logs/query-timeseries/index.ts +++ b/apps/dashboard/lib/trpc/routers/logs/query-timeseries/index.ts @@ -1,6 +1,5 @@ import { queryTimeseriesPayload } from "@/app/(app)/logs/components/charts/query-timeseries.schema"; import { clickhouse } from "@/lib/clickhouse"; -import { db } from "@/lib/db"; import { rateLimitedProcedure, ratelimit } from "@/lib/trpc/ratelimitProcedure"; import { TRPCError } from "@trpc/server"; import { transformFilters } from "./utils"; @@ -8,30 +7,10 @@ import { transformFilters } from "./utils"; export const queryTimeseries = rateLimitedProcedure(ratelimit.update) .input(queryTimeseriesPayload) .query(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "Failed to retrieve timeseries analytics due to an workspace error. If this issue persists, please contact support@unkey.dev with the time this occurred.", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, please contact support using support@unkey.dev.", - }); - } - const { params: transformedInputs, granularity } = transformFilters(input); const result = await clickhouse.api.timeseries[granularity]({ ...transformedInputs, - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, }); if (result.err) { diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/createNamespace.ts b/apps/dashboard/lib/trpc/routers/ratelimit/createNamespace.ts index 181adbb14e..1655415595 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/createNamespace.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/createNamespace.ts @@ -14,38 +14,18 @@ export const createNamespace = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const ws = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to create a new namespace. Please try again or contact support@unkey.dev", - }); - }); - if (!ws) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - const namespaceId = newId("ratelimitNamespace"); await db .transaction(async (tx) => { await tx.insert(schema.ratelimitNamespaces).values({ id: namespaceId, name: input.name, - workspaceId: ws.id, + workspaceId: ctx.workspace.id, createdAt: new Date(), }); - await insertAuditLogs(tx, { - workspaceId: ws.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/createOverride.ts b/apps/dashboard/lib/trpc/routers/ratelimit/createOverride.ts index 839424c874..e002a9d330 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/createOverride.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/createOverride.ts @@ -20,16 +20,11 @@ export const createOverride = t.procedure const namespace = await db.query.ratelimitNamespaces .findFirst({ where: (table, { and, eq, isNull }) => - and(eq(table.id, input.namespaceId), isNull(table.deletedAt)), - with: { - workspace: { - columns: { - id: true, - tenantId: true, - features: true, - }, - }, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.namespaceId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -38,7 +33,7 @@ export const createOverride = t.procedure "We are unable to create an override for this namespace. Please try again or contact support@unkey.dev", }); }); - if (!namespace || namespace.workspace.tenantId !== ctx.tenant.id) { + if (!namespace) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -60,8 +55,8 @@ export const createOverride = t.procedure ) .then((res) => Number(res.at(0)?.count ?? 0)); const max = - typeof namespace.workspace.features.ratelimitOverrides === "number" - ? namespace.workspace.features.ratelimitOverrides + typeof ctx.workspace.features.ratelimitOverrides === "number" + ? ctx.workspace.features.ratelimitOverrides : 5; if (existing >= max) { throw new TRPCError({ @@ -71,7 +66,7 @@ export const createOverride = t.procedure } await tx.insert(schema.ratelimitOverrides).values({ - workspaceId: namespace.workspace.id, + workspaceId: ctx.workspace.id, namespaceId: namespace.id, identifier: input.identifier, id, @@ -80,8 +75,8 @@ export const createOverride = t.procedure createdAt: new Date(), async: input.async, }); - await insertAuditLogs(tx, { - workspaceId: namespace.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/deleteNamespace.ts b/apps/dashboard/lib/trpc/routers/ratelimit/deleteNamespace.ts index 02f327c48f..850e1f313b 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/deleteNamespace.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/deleteNamespace.ts @@ -15,16 +15,11 @@ export const deleteNamespace = t.procedure const namespace = await db.query.ratelimitNamespaces .findFirst({ where: (table, { eq, and, isNull }) => - and(eq(table.id, input.namespaceId), isNull(table.deletedAt)), - - with: { - workspace: { - columns: { - id: true, - tenantId: true, - }, - }, - }, + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.namespaceId), + isNull(table.deletedAt), + ), }) .catch((_err) => { throw new TRPCError({ @@ -33,7 +28,7 @@ export const deleteNamespace = t.procedure "We are unable to delete namespace. Please try again or contact support@unkey.dev", }); }); - if (!namespace || namespace.workspace.tenantId !== ctx.tenant.id) { + if (!namespace) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -47,7 +42,7 @@ export const deleteNamespace = t.procedure .set({ deletedAt: new Date() }) .where(eq(schema.ratelimitNamespaces.id, input.namespaceId)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: namespace.workspaceId, actor: { type: "user", @@ -86,8 +81,9 @@ export const deleteNamespace = t.procedure }); await insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, overrides.map(({ id }) => ({ - workspaceId: namespace.workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/deleteOverride.ts b/apps/dashboard/lib/trpc/routers/ratelimit/deleteOverride.ts index e505ac5f22..298bf5261b 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/deleteOverride.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/deleteOverride.ts @@ -14,20 +14,17 @@ export const deleteOverride = t.procedure .mutation(async ({ ctx, input }) => { const override = await db.query.ratelimitOverrides .findFirst({ - where: (table, { and, eq, isNull }) => and(eq(table.id, input.id), isNull(table.deletedAt)), + where: (table, { and, eq, isNull }) => + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.id), + isNull(table.deletedAt), + ), with: { namespace: { columns: { id: true, }, - with: { - workspace: { - columns: { - id: true, - tenantId: true, - }, - }, - }, }, }, }) @@ -39,7 +36,7 @@ export const deleteOverride = t.procedure }); }); - if (!override || override.namespace.workspace.tenantId !== ctx.tenant.id) { + if (!override) { throw new TRPCError({ message: "We are unable to find the correct override. Please try again or contact support@unkey.dev.", @@ -59,8 +56,8 @@ export const deleteOverride = t.procedure code: "INTERNAL_SERVER_ERROR", }); }); - await insertAuditLogs(tx, { - workspaceId: override.namespace.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/updateNamespaceName.ts b/apps/dashboard/lib/trpc/routers/ratelimit/updateNamespaceName.ts index 2251988124..dbcb36b157 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/updateNamespaceName.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/updateNamespaceName.ts @@ -14,17 +14,16 @@ export const updateNamespaceName = t.procedure }), ) .mutation(async ({ ctx, input }) => { - const ws = await db.query.workspaces + const namespace = await db.query.ratelimitNamespaces .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.id, input.workspaceId), isNull(table.deletedAt)), - with: { - ratelimitNamespaces: { - where: (table, { eq, and, isNull }) => - and(isNull(table.deletedAt), eq(schema.ratelimitNamespaces.id, input.namespaceId)), - }, - }, + where: (table, { eq, and, isNull }) => + and( + eq(table.workspaceId, ctx.workspace.id), + isNull(table.deletedAt), + eq(table.id, input.namespaceId), + ), }) + .catch((_err) => { throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", @@ -33,18 +32,10 @@ export const updateNamespaceName = t.procedure }); }); - if (!ws || ws.tenantId !== ctx.tenant.id) { - throw new TRPCError({ - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev", - code: "NOT_FOUND", - }); - } - const namespace = ws.ratelimitNamespaces.find((ns) => ns.id === input.namespaceId); if (!namespace) { throw new TRPCError({ message: - "We are unable to find the correct namespace. Please try again or contact support@unkey.dev", + "We are unable to find the correct workspace. Please try again or contact support@unkey.dev", code: "NOT_FOUND", }); } @@ -63,8 +54,8 @@ export const updateNamespaceName = t.procedure code: "INTERNAL_SERVER_ERROR", }); }); - await insertAuditLogs(tx, { - workspaceId: ws.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/ratelimit/updateOverride.ts b/apps/dashboard/lib/trpc/routers/ratelimit/updateOverride.ts index 9ab5058972..0e5b9b7a14 100644 --- a/apps/dashboard/lib/trpc/routers/ratelimit/updateOverride.ts +++ b/apps/dashboard/lib/trpc/routers/ratelimit/updateOverride.ts @@ -17,20 +17,17 @@ export const updateOverride = t.procedure .mutation(async ({ ctx, input }) => { const override = await db.query.ratelimitOverrides .findFirst({ - where: (table, { and, eq, isNull }) => and(eq(table.id, input.id), isNull(table.deletedAt)), + where: (table, { and, eq, isNull }) => + and( + eq(table.workspaceId, ctx.workspace.id), + eq(table.id, input.id), + isNull(table.deletedAt), + ), with: { namespace: { columns: { id: true, }, - with: { - workspace: { - columns: { - id: true, - tenantId: true, - }, - }, - }, }, }, }) @@ -42,7 +39,7 @@ export const updateOverride = t.procedure }); }); - if (!override || override.namespace.workspace.tenantId !== ctx.tenant.id) { + if (!override) { throw new TRPCError({ message: "We are unable to find the correct override. Please try again or contact support@unkey.dev.", @@ -68,8 +65,8 @@ export const updateOverride = t.procedure code: "INTERNAL_SERVER_ERROR", }); }); - await insertAuditLogs(tx, { - workspaceId: override.namespace.workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id, diff --git a/apps/dashboard/lib/trpc/routers/rbac.ts b/apps/dashboard/lib/trpc/routers/rbac.ts index 0514d6cfdf..8a8142f522 100644 --- a/apps/dashboard/lib/trpc/routers/rbac.ts +++ b/apps/dashboard/lib/trpc/routers/rbac.ts @@ -75,7 +75,7 @@ export const rbacRouter = t.router({ workspaceId: permissions[0].workspaceId, }) .onDuplicateKeyUpdate({ set: { permissionId: permissions[0].id } }); - await insertAuditLogs(tx, auditLogs); + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, auditLogs); }); }), removePermissionFromRootKey: rateLimitedProcedure(ratelimit.update) @@ -135,7 +135,7 @@ export const rbacRouter = t.router({ eq(schema.keysPermissions.permissionId, permissionRelation.permissionId), ), ); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: permissionRelation.workspaceId, actor: { type: "user", id: ctx.user!.id }, event: "authorization.disconnect_permission_and_key", @@ -210,7 +210,7 @@ export const rbacRouter = t.router({ .onDuplicateKeyUpdate({ set: { ...tuple, updatedAt: new Date() }, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: tuple.workspaceId, actor: { type: "user", id: ctx.user!.id }, event: "authorization.connect_role_and_permission", @@ -260,7 +260,7 @@ export const rbacRouter = t.router({ eq(schema.rolesPermissions.permissionId, input.permissionId), ), ); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, actor: { type: "user", id: ctx.user!.id }, event: "authorization.disconnect_role_and_permissions", @@ -335,7 +335,7 @@ export const rbacRouter = t.router({ .onDuplicateKeyUpdate({ set: { ...tuple, updatedAt: new Date() }, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: tuple.workspaceId, actor: { type: "user", id: ctx.user!.id }, event: "authorization.connect_role_and_key", @@ -413,7 +413,7 @@ export const rbacRouter = t.router({ description: input.description, workspaceId: workspace.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "role.create", actor: { @@ -444,6 +444,7 @@ export const rbacRouter = t.router({ ); await insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, input.permissionIds.map((permissionId) => ({ workspaceId: workspace.id, event: "authorization.connect_role_and_permission", @@ -503,7 +504,7 @@ export const rbacRouter = t.router({ } await db.transaction(async (tx) => { await tx.update(schema.roles).set(input).where(eq(schema.roles.id, input.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "role.update", actor: { @@ -555,7 +556,7 @@ export const rbacRouter = t.router({ .where( and(eq(schema.roles.id, input.roleId), eq(schema.roles.workspaceId, workspace.id)), ); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "role.delete", actor: { @@ -599,7 +600,7 @@ export const rbacRouter = t.router({ description: input.description, workspaceId: workspace.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "permission.create", actor: { @@ -663,7 +664,7 @@ export const rbacRouter = t.router({ updatedAt: new Date(), }) .where(eq(schema.permissions.id, input.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "permission.update", actor: { @@ -723,7 +724,7 @@ export const rbacRouter = t.router({ eq(schema.permissions.workspaceId, workspace.id), ), ); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, event: "permission.delete", actor: { diff --git a/apps/dashboard/lib/trpc/routers/rbac/addPermissionToRootKey.ts b/apps/dashboard/lib/trpc/routers/rbac/addPermissionToRootKey.ts index 765a59c0fa..bc3244bec0 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/addPermissionToRootKey.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/addPermissionToRootKey.ts @@ -22,29 +22,9 @@ export const addPermissionToRootKey = t.procedure }); } - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to add permission to the rootkey. Please try again or contact support@unkey.dev", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - const rootKey = await db.query.keys.findFirst({ where: (table, { eq, and }) => - and(eq(table.forWorkspaceId, workspace.id), eq(table.id, input.rootKeyId)), + and(eq(table.forWorkspaceId, ctx.workspace.id), eq(table.id, input.rootKeyId)), with: { permissions: { with: { @@ -82,10 +62,10 @@ export const addPermissionToRootKey = t.procedure "We are unable to add permission to the root key. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, [ + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, [ ...auditLogs, { - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.connect_permission_and_key", description: `Attached ${p.id} to ${rootKey.id}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/connectPermissionToRole.ts b/apps/dashboard/lib/trpc/routers/rbac/connectPermissionToRole.ts index 0f863ff81f..b1d646055a 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/connectPermissionToRole.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/connectPermissionToRole.ts @@ -76,7 +76,7 @@ export const connectPermissionToRole = t.procedure "We are unable to connect the permission to the role. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.connect_role_and_permission", diff --git a/apps/dashboard/lib/trpc/routers/rbac/connectRoleToKey.ts b/apps/dashboard/lib/trpc/routers/rbac/connectRoleToKey.ts index 92a58b9663..bd215c4bbb 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/connectRoleToKey.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/connectRoleToKey.ts @@ -76,7 +76,7 @@ export const connectRoleToKey = t.procedure "We are unable to connect the role and key. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.connect_role_and_key", diff --git a/apps/dashboard/lib/trpc/routers/rbac/createPermission.ts b/apps/dashboard/lib/trpc/routers/rbac/createPermission.ts index 4431ebbb3d..a57262449f 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/createPermission.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/createPermission.ts @@ -21,32 +21,12 @@ export const createPermission = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to create permission. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } const permissionId = newId("permission"); await db .transaction(async (tx) => { const existing = await tx.query.permissions.findFirst({ where: (table, { and, eq }) => - and(eq(table.workspaceId, workspace.id), eq(table.name, input.name)), + and(eq(table.workspaceId, ctx.workspace.id), eq(table.name, input.name)), }); if (existing) { throw new TRPCError({ @@ -60,10 +40,10 @@ export const createPermission = t.procedure id: permissionId, name: input.name, description: input.description, - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, }); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, event: "permission.create", actor: { type: "user", diff --git a/apps/dashboard/lib/trpc/routers/rbac/createRole.ts b/apps/dashboard/lib/trpc/routers/rbac/createRole.ts index 9a40ceff03..fdd6b09803 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/createRole.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/createRole.ts @@ -22,31 +22,12 @@ export const createRole = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "We are unable to create role. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } const roleId = newId("role"); await db .transaction(async (tx) => { const existing = await tx.query.roles.findFirst({ where: (table, { and, eq }) => - and(eq(table.workspaceId, workspace.id), eq(table.name, input.name)), + and(eq(table.workspaceId, ctx.workspace.id), eq(table.name, input.name)), }); if (existing) { throw new TRPCError({ @@ -62,7 +43,7 @@ export const createRole = t.procedure id: roleId, name: input.name, description: input.description, - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, }) .catch((_err) => { throw new TRPCError({ @@ -71,8 +52,8 @@ export const createRole = t.procedure "We are unable to create a role. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, event: "role.create", actor: { type: "user", @@ -97,13 +78,14 @@ export const createRole = t.procedure input.permissionIds.map((permissionId) => ({ permissionId, roleId: roleId, - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, })), ); await insertAuditLogs( tx, + ctx.workspace.auditLogBucket.id, input.permissionIds.map((permissionId) => ({ - workspaceId: workspace.id, + workspaceId: ctx.workspace.id, event: "authorization.connect_role_and_permission", actor: { type: "user", diff --git a/apps/dashboard/lib/trpc/routers/rbac/deletePermission.ts b/apps/dashboard/lib/trpc/routers/rbac/deletePermission.ts index c978716fa0..de063da541 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/deletePermission.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/deletePermission.ts @@ -11,50 +11,30 @@ export const deletePermission = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - with: { - permissions: { - where: (table, { eq }) => eq(table.id, input.permissionId), - }, - }, - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to delete this permission. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - if (workspace.permissions.length === 0) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct permission. Please try again or contact support@unkey.dev.", - }); - } await db .transaction(async (tx) => { + const permission = await tx.query.permissions.findFirst({ + where: (table, { and, eq }) => + and(eq(table.workspaceId, ctx.workspace.id), eq(table.id, input.permissionId)), + }); + + if (!permission) { + throw new TRPCError({ + code: "NOT_FOUND", + message: + "We are unable to find the correct permission. Please try again or contact support@unkey.dev.", + }); + } await tx .delete(schema.permissions) .where( and( - eq(schema.permissions.id, input.permissionId), - eq(schema.permissions.workspaceId, workspace.id), + eq(schema.permissions.id, permission.id), + eq(schema.permissions.workspaceId, ctx.workspace.id), ), ); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "permission.delete", description: `Deleted permission ${input.permissionId}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/deleteRole.ts b/apps/dashboard/lib/trpc/routers/rbac/deleteRole.ts index b59496b7a1..1e22384fac 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/deleteRole.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/deleteRole.ts @@ -11,50 +11,34 @@ export const deleteRole = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - with: { - roles: { - where: (table, { eq }) => eq(table.id, input.roleId), - }, - }, - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "We are unable to delete role. Please try again or contact support@unkey.dev", - }); + await db.transaction(async (tx) => { + const role = await tx.query.roles.findFirst({ + where: (table, { and, eq }) => + and(eq(table.workspaceId, ctx.workspace.id), eq(table.id, input.roleId)), }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - if (workspace.roles.length === 0) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct role. Please try again or contact support@unkey.dev.", - }); - } - await db.transaction(async (tx) => { + if (!role) { + throw new TRPCError({ + code: "NOT_FOUND", + message: + "We are unable to find the correct role. Please try again or contact support@unkey.dev.", + }); + } await tx .delete(schema.roles) - .where(and(eq(schema.roles.id, input.roleId), eq(schema.roles.workspaceId, workspace.id))) - .catch((_err) => { + .where( + and(eq(schema.roles.id, input.roleId), eq(schema.roles.workspaceId, ctx.workspace.id)), + ) + .catch((err) => { + console.error("Failed to delete role:", err); throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "We are unable to delete the role. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "role.delete", description: `Deleted role ${input.roleId}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/disconnectPermissionFromRole.ts b/apps/dashboard/lib/trpc/routers/rbac/disconnectPermissionFromRole.ts index a6023c03e9..987f30d062 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/disconnectPermissionFromRole.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/disconnectPermissionFromRole.ts @@ -12,38 +12,19 @@ export const disconnectPermissionFromRole = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to remove permission from the role. Please try again or contact support@unkey.dev", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } await db .transaction(async (tx) => { await tx .delete(schema.rolesPermissions) .where( and( - eq(schema.rolesPermissions.workspaceId, workspace.id), + eq(schema.rolesPermissions.workspaceId, ctx.workspace.id), eq(schema.rolesPermissions.roleId, input.roleId), eq(schema.rolesPermissions.permissionId, input.permissionId), ), ); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.disconnect_role_and_permissions", description: `Disconnect role ${input.roleId} from permission ${input.permissionId}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/disconnectRoleFromKey.ts b/apps/dashboard/lib/trpc/routers/rbac/disconnectRoleFromKey.ts index 3610cac54f..afc9de24d3 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/disconnectRoleFromKey.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/disconnectRoleFromKey.ts @@ -12,38 +12,19 @@ export const disconnectRoleFromKey = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to disconnect the role from the key. Please try again or contact support@unkey.dev", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } await db .transaction(async (tx) => { await tx .delete(schema.keysRoles) .where( and( - eq(schema.keysRoles.workspaceId, workspace.id), + eq(schema.keysRoles.workspaceId, ctx.workspace.id), eq(schema.keysRoles.roleId, input.roleId), eq(schema.keysRoles.keyId, input.keyId), ), ); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.disconnect_role_and_key", description: `Disconnect role ${input.roleId} from ${input.keyId}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/removePermissionFromRootKey.ts b/apps/dashboard/lib/trpc/routers/rbac/removePermissionFromRootKey.ts index 32aff38da2..734f027ca1 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/removePermissionFromRootKey.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/removePermissionFromRootKey.ts @@ -12,33 +12,12 @@ export const removePermissionFromRootKey = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to remove permission from the root key. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - await db .transaction(async (tx) => { const key = await tx.query.keys.findFirst({ where: (table, { and, eq, isNull }) => and( - eq(schema.keys.forWorkspaceId, workspace.id), + eq(schema.keys.forWorkspaceId, ctx.workspace.id), eq(schema.keys.id, input.rootKeyId), isNull(table.deletedAt), ), @@ -77,8 +56,8 @@ export const removePermissionFromRootKey = t.procedure eq(schema.keysPermissions.permissionId, permissionRelation.permissionId), ), ); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "authorization.disconnect_permission_and_key", description: `Disconnect ${input.permissionName} from ${input.rootKeyId}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/updatePermission.ts b/apps/dashboard/lib/trpc/routers/rbac/updatePermission.ts index 9e77e15f0e..0b08134e34 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/updatePermission.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/updatePermission.ts @@ -21,32 +21,12 @@ export const updatePermission = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - with: { - permissions: { - where: (table, { eq }) => eq(table.id, input.id), - }, - }, - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to update permission. Please try again or contact support@unkey.dev", - }); - }); + const permission = await db.query.permissions.findFirst({ + where: (table, { and, eq }) => + and(eq(table.workspaceId, ctx.workspace.id), eq(table.id, input.id)), + }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - if (workspace.permissions.length === 0) { + if (!permission) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -63,9 +43,9 @@ export const updatePermission = t.procedure description: input.description, updatedAt: new Date(), }) - .where(eq(schema.permissions.id, input.id)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + .where(eq(schema.permissions.id, permission.id)); + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "permission.update", description: `Update permission ${input.id}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/updateRole.ts b/apps/dashboard/lib/trpc/routers/rbac/updateRole.ts index 0843d7f406..570758877f 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/updateRole.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/updateRole.ts @@ -21,33 +21,12 @@ export const updateRole = t.procedure }), ) .mutation(async ({ input, ctx }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - with: { - roles: { - where: (table, { eq }) => eq(table.id, input.id), - }, - }, - }) - .catch((err) => { - console.error(err); - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to update the role. Please try again or contact support@unkey.dev", - }); - }); + const role = await db.query.roles.findFirst({ + where: (table, { and, eq }) => + and(eq(table.workspaceId, ctx.workspace.id), eq(table.id, input.id)), + }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: - "We are unable to find the correct workspace. Please try again or contact support@unkey.dev.", - }); - } - if (workspace.roles.length === 0) { + if (!role) { throw new TRPCError({ code: "NOT_FOUND", message: @@ -56,9 +35,9 @@ export const updateRole = t.procedure } await db .transaction(async (tx) => { - await tx.update(schema.roles).set(input).where(eq(schema.roles.id, input.id)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await tx.update(schema.roles).set(input).where(eq(schema.roles.id, role.id)); + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "role.update", description: `Updated role ${input.id}`, diff --git a/apps/dashboard/lib/trpc/routers/rbac/upsertPermission.ts b/apps/dashboard/lib/trpc/routers/rbac/upsertPermission.ts index fdd06c02a5..7fb1435dde 100644 --- a/apps/dashboard/lib/trpc/routers/rbac/upsertPermission.ts +++ b/apps/dashboard/lib/trpc/routers/rbac/upsertPermission.ts @@ -4,16 +4,12 @@ import { TRPCError } from "@trpc/server"; import { newId } from "@unkey/id"; import type { Context } from "../../context"; -export async function upsertPermission( - ctx: Context, - workspaceId: string, - name: string, -): Promise { +export async function upsertPermission(ctx: Context, name: string): Promise { return await db.transaction(async (tx) => { const existingPermission = await tx.query.permissions .findFirst({ where: (table, { and, eq }) => - and(eq(table.workspaceId, workspaceId), eq(table.name, name)), + and(eq(table.workspaceId, ctx.workspace!.id), eq(table.name, name)), }) .catch((_err) => { throw new TRPCError({ @@ -28,7 +24,7 @@ export async function upsertPermission( const permission: Permission = { id: newId("permission"), - workspaceId, + workspaceId: ctx.workspace!.id, name, description: null, createdAt: new Date(), @@ -45,8 +41,8 @@ export async function upsertPermission( "We are unable to upsert the permission. Please try again or contact support@unkey.dev.", }); }); - await insertAuditLogs(tx, { - workspaceId, + await insertAuditLogs(tx, ctx.workspace!.auditLogBucket.id, { + workspaceId: ctx.workspace!.id, actor: { type: "user", id: ctx.user!.id }, event: "permission.create", description: `Created ${permission.id}`, diff --git a/apps/dashboard/lib/trpc/routers/vercel.ts b/apps/dashboard/lib/trpc/routers/vercel.ts index de7c37cd04..b9f5ffd667 100644 --- a/apps/dashboard/lib/trpc/routers/vercel.ts +++ b/apps/dashboard/lib/trpc/routers/vercel.ts @@ -73,7 +73,7 @@ export const vercelRouter = t.router({ remaining: null, deletedAt: null, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "key.create", @@ -123,7 +123,7 @@ export const vercelRouter = t.router({ workspaceId: integration.workspace.id, integrationId: integration.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.create", @@ -174,7 +174,7 @@ export const vercelRouter = t.router({ workspaceId: integration.workspace.id, integrationId: integration.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.create", @@ -254,7 +254,7 @@ export const vercelRouter = t.router({ lastEditedBy: ctx.user.id, }) .where(eq(schema.vercelBindings.id, existingBinding.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.update", @@ -294,7 +294,7 @@ export const vercelRouter = t.router({ workspaceId: integration.workspace.id, integrationId: integration.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.create", @@ -374,7 +374,7 @@ export const vercelRouter = t.router({ remaining: null, deletedAt: null, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "key.create", @@ -419,7 +419,7 @@ export const vercelRouter = t.router({ lastEditedBy: ctx.user.id, }) .where(eq(schema.vercelBindings.id, existingBinding.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.update", @@ -460,7 +460,7 @@ export const vercelRouter = t.router({ integrationId: integration.id, }); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.create", @@ -527,7 +527,7 @@ export const vercelRouter = t.router({ .update(schema.vercelBindings) .set({ deletedAt: new Date() }) .where(eq(schema.vercelBindings.id, binding.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: binding.vercelIntegrations.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.delete", @@ -584,7 +584,7 @@ export const vercelRouter = t.router({ .update(schema.vercelBindings) .set({ deletedAt: new Date() }) .where(eq(schema.vercelBindings.id, binding.id)); - await insertAuditLogs(tx, { + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { workspaceId: integration.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "vercelBinding.delete", diff --git a/apps/dashboard/lib/trpc/routers/workspace/changeName.ts b/apps/dashboard/lib/trpc/routers/workspace/changeName.ts index 67e1ad6206..4681bb3e90 100644 --- a/apps/dashboard/lib/trpc/routers/workspace/changeName.ts +++ b/apps/dashboard/lib/trpc/routers/workspace/changeName.ts @@ -13,21 +13,6 @@ export const changeWorkspaceName = t.procedure }), ) .mutation(async ({ ctx, input }) => { - const ws = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.id, input.workspaceId), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable to update the workspace name. Please try again or contact support@unkey.dev", - }); - }); - if (!ws || ws.tenantId !== ctx.tenant.id) { - throw new Error("Workspace not found, Please sign back in and try again"); - } await db .transaction(async (tx) => { await tx @@ -43,15 +28,15 @@ export const changeWorkspaceName = t.procedure "We are unable to update the workspace name. Please try again or contact support@unkey.dev", }); }); - await insertAuditLogs(tx, { - workspaceId: ws.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "workspace.update", - description: `Changed name from ${ws.name} to ${input.name}`, + description: `Changed name from ${ctx.workspace.name} to ${input.name}`, resources: [ { type: "workspace", - id: ws.id, + id: ctx.workspace.id, }, ], context: { diff --git a/apps/dashboard/lib/trpc/routers/workspace/changePlan.ts b/apps/dashboard/lib/trpc/routers/workspace/changePlan.ts index bf8af62d2f..0e5cfe30c5 100644 --- a/apps/dashboard/lib/trpc/routers/workspace/changePlan.ts +++ b/apps/dashboard/lib/trpc/routers/workspace/changePlan.ts @@ -26,37 +26,13 @@ export const changeWorkspacePlan = t.procedure apiVersion: "2023-10-16", typescript: true, }); - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.id, input.workspaceId), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "We are unable to change plans. Please try again or contact support@unkey.dev", - }); - }); - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, Please try again or contact support@unkey.dev.", - }); - } - if (workspace.tenantId !== ctx.tenant.id) { - throw new TRPCError({ - code: "UNAUTHORIZED", - message: - "You do not have permission to modify this workspace. Please speak to your organization's administrator.", - }); - } const now = new Date(); if ( - workspace.planChanged && - workspace.planChanged.getUTCFullYear() === now.getUTCFullYear() && - workspace.planChanged.getUTCMonth() === now.getUTCMonth() + ctx.workspace.planChanged && + ctx.workspace.planChanged.getUTCFullYear() === now.getUTCFullYear() && + ctx.workspace.planChanged.getUTCMonth() === now.getUTCMonth() ) { throw new TRPCError({ code: "PRECONDITION_FAILED", @@ -65,8 +41,8 @@ export const changeWorkspacePlan = t.procedure }); } - if (workspace.plan === input.plan) { - if (workspace.planDowngradeRequest) { + if (ctx.workspace.plan === input.plan) { + if (ctx.workspace.planDowngradeRequest) { // The user wants to resubscribe await db .transaction(async (tx) => { @@ -77,15 +53,15 @@ export const changeWorkspacePlan = t.procedure }) .where(eq(schema.workspaces.id, input.workspaceId)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "workspace.update", description: "Removed downgrade request", resources: [ { type: "workspace", - id: workspace.id, + id: ctx.workspace.id, }, ], context: { @@ -121,15 +97,15 @@ export const changeWorkspacePlan = t.procedure planDowngradeRequest: "free", }) .where(eq(schema.workspaces.id, input.workspaceId)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "workspace.update", description: "Requested downgrade to 'free'", resources: [ { type: "workspace", - id: workspace.id, + id: ctx.workspace.id, }, ], context: { @@ -145,14 +121,14 @@ export const changeWorkspacePlan = t.procedure }; } case "pro": { - if (!workspace.stripeCustomerId) { + if (!ctx.workspace.stripeCustomerId) { throw new TRPCError({ code: "PRECONDITION_FAILED", message: "You do not have a payment method. Please add one before upgrading.", }); } const paymentMethods = await stripe.customers.listPaymentMethods( - workspace.stripeCustomerId, + ctx.workspace.stripeCustomerId, ); if (!paymentMethods || paymentMethods.data.length === 0) { throw new TRPCError({ @@ -171,15 +147,15 @@ export const changeWorkspacePlan = t.procedure planDowngradeRequest: null, }) .where(eq(schema.workspaces.id, input.workspaceId)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "workspace.update", description: "Changed plan to 'pro'", resources: [ { type: "workspace", - id: workspace.id, + id: ctx.workspace.id, }, ], context: { diff --git a/apps/dashboard/lib/trpc/routers/workspace/create.ts b/apps/dashboard/lib/trpc/routers/workspace/create.ts index cd158b026b..b38578064f 100644 --- a/apps/dashboard/lib/trpc/routers/workspace/create.ts +++ b/apps/dashboard/lib/trpc/routers/workspace/create.ts @@ -60,7 +60,7 @@ export const createWorkspace = t.procedure name: "unkey_mutations", deleteProtection: true, }); - await insertAuditLogs(tx, [ + await insertAuditLogs(tx, auditLogBucketId, [ { workspaceId: workspace.id, actor: { type: "user", id: ctx.user.id }, diff --git a/apps/dashboard/lib/trpc/routers/workspace/optIntoBeta.ts b/apps/dashboard/lib/trpc/routers/workspace/optIntoBeta.ts index ca4d9a39e0..2d16c01f6a 100644 --- a/apps/dashboard/lib/trpc/routers/workspace/optIntoBeta.ts +++ b/apps/dashboard/lib/trpc/routers/workspace/optIntoBeta.ts @@ -12,37 +12,13 @@ export const optWorkspaceIntoBeta = t.procedure }), ) .mutation(async ({ ctx, input }) => { - const workspace = await db.query.workspaces - .findFirst({ - where: (table, { and, eq, isNull }) => - and(eq(table.tenantId, ctx.tenant.id), isNull(table.deletedAt)), - }) - .catch((_err) => { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: - "We are unable opt you in to this beta feature. Please try again or contact support@unkey.dev", - }); - }); - - if (!workspace) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Workspace not found, Please try again or contact support@unkey.dev.", - }); - } - switch (input.feature) { case "rbac": { - workspace.betaFeatures.rbac = true; + ctx.workspace.betaFeatures.rbac = true; break; } case "identities": { - workspace.betaFeatures.identities = true; - break; - } - case "ratelimit": { - workspace.betaFeatures.ratelimit = true; + ctx.workspace.betaFeatures.identities = true; break; } } @@ -51,18 +27,18 @@ export const optWorkspaceIntoBeta = t.procedure await tx .update(schema.workspaces) .set({ - betaFeatures: workspace.betaFeatures, + betaFeatures: ctx.workspace.betaFeatures, }) - .where(eq(schema.workspaces.id, workspace.id)); - await insertAuditLogs(tx, { - workspaceId: workspace.id, + .where(eq(schema.workspaces.id, ctx.workspace.id)); + await insertAuditLogs(tx, ctx.workspace.auditLogBucket.id, { + workspaceId: ctx.workspace.id, actor: { type: "user", id: ctx.user.id }, event: "workspace.opt_in", - description: `Opted ${workspace.id} into beta: ${input.feature}`, + description: `Opted ${ctx.workspace.id} into beta: ${input.feature}`, resources: [ { type: "workspace", - id: workspace.id, + id: ctx.workspace.id, }, ], context: { diff --git a/apps/dashboard/lib/trpc/trpc.ts b/apps/dashboard/lib/trpc/trpc.ts index 4d78fd11c6..b9140bda00 100644 --- a/apps/dashboard/lib/trpc/trpc.ts +++ b/apps/dashboard/lib/trpc/trpc.ts @@ -9,9 +9,13 @@ export const auth = t.middleware(({ next, ctx }) => { if (!ctx.user?.id) { throw new TRPCError({ code: "UNAUTHORIZED" }); } + if (!ctx.workspace) { + throw new TRPCError({ code: "NOT_FOUND", message: "workspace not found in context" }); + } return next({ ctx: { + workspace: ctx.workspace, user: ctx.user, tenant: ctx.tenant ?? { id: ctx.user.id, role: "owner" }, }, diff --git a/go/buf.gen.yaml b/go/buf.gen.yaml new file mode 100644 index 0000000000..240a7ace7e --- /dev/null +++ b/go/buf.gen.yaml @@ -0,0 +1,10 @@ +version: v2 +plugins: + - remote: buf.build/protocolbuffers/go + out: gen + opt: paths=source_relative + - remote: buf.build/connectrpc/go:v1.16.2 + out: gen + opt: paths=source_relative + + \ No newline at end of file diff --git a/go/buf.yaml b/go/buf.yaml new file mode 100644 index 0000000000..c42a77e898 --- /dev/null +++ b/go/buf.yaml @@ -0,0 +1,10 @@ +version: v1 +breaking: + use: + - FILE + - PACKAGE + - WIRE + - WIRE_JSON +lint: + use: + - STANDARD diff --git a/go/cmd/api/routes/v2_liveness/handler.go b/go/cmd/api/routes/v2_liveness/handler.go new file mode 100644 index 0000000000..31bfb5ee51 --- /dev/null +++ b/go/cmd/api/routes/v2_liveness/handler.go @@ -0,0 +1,20 @@ +package v2Liveness + +import ( + "net/http" + + openapi "github.com/unkeyed/unkey/go/api" + zen "github.com/unkeyed/unkey/go/pkg/zen" +) + +type Response = openapi.V2LivenessResponseBody + +func New() zen.Route { + return zen.NewRoute("GET", "/v2/liveness", func(s *zen.Session) error { + + res := Response{ + Message: "we're cooking", + } + return s.JSON(http.StatusOK, res) + }) +} diff --git a/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go b/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go new file mode 100644 index 0000000000..717ab9981e --- /dev/null +++ b/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go @@ -0,0 +1,146 @@ +// Code generated by protoc-gen-connect-go. DO NOT EDIT. +// +// Source: proto/ratelimit/v1/service.proto + +package ratelimitv1connect + +import ( + connect "connectrpc.com/connect" + context "context" + errors "errors" + v1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1" + http "net/http" + strings "strings" +) + +// This is a compile-time assertion to ensure that this generated file and the connect package are +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. +const _ = connect.IsAtLeastVersion1_13_0 + +const ( + // RatelimitServiceName is the fully-qualified name of the RatelimitService service. + RatelimitServiceName = "ratelimit.v1.RatelimitService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. +const ( + // RatelimitServiceReplayProcedure is the fully-qualified name of the RatelimitService's Replay RPC. + RatelimitServiceReplayProcedure = "/ratelimit.v1.RatelimitService/Replay" +) + +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + ratelimitServiceServiceDescriptor = v1.File_proto_ratelimit_v1_service_proto.Services().ByName("RatelimitService") + ratelimitServiceReplayMethodDescriptor = ratelimitServiceServiceDescriptor.Methods().ByName("Replay") +) + +// RatelimitServiceClient is a client for the ratelimit.v1.RatelimitService service. +type RatelimitServiceClient interface { + // Replay synchronizes rate limit state between nodes using consistent hashing. + // + // Key behaviors: + // - Each identifier maps to exactly one origin server via consistent hashing + // - Edge nodes replay their local rate limit decisions to the origin + // - Origin maintains the source of truth for rate limit state + // - Edge nodes must update their state based on origin responses + // + // Flow: + // 1. Edge node receives rate limit request + // 2. Edge makes local decision (may be defensive) + // 3. Edge replays decision to origin + // 4. Origin processes and returns authoritative state + // 5. Edge updates local state and returns result to client + // + // This approach ensures eventual consistency while allowing for + // fast local decisions at the edge. + Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) +} + +// NewRatelimitServiceClient constructs a client for the ratelimit.v1.RatelimitService service. By +// default, it uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, +// and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the +// connect.WithGRPC() or connect.WithGRPCWeb() options. +// +// The URL supplied here should be the base URL for the Connect or gRPC server (for example, +// http://api.acme.com or https://acme.com/grpc). +func NewRatelimitServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) RatelimitServiceClient { + baseURL = strings.TrimRight(baseURL, "/") + return &ratelimitServiceClient{ + replay: connect.NewClient[v1.ReplayRequest, v1.ReplayResponse]( + httpClient, + baseURL+RatelimitServiceReplayProcedure, + connect.WithSchema(ratelimitServiceReplayMethodDescriptor), + connect.WithClientOptions(opts...), + ), + } +} + +// ratelimitServiceClient implements RatelimitServiceClient. +type ratelimitServiceClient struct { + replay *connect.Client[v1.ReplayRequest, v1.ReplayResponse] +} + +// Replay calls ratelimit.v1.RatelimitService.Replay. +func (c *ratelimitServiceClient) Replay(ctx context.Context, req *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) { + return c.replay.CallUnary(ctx, req) +} + +// RatelimitServiceHandler is an implementation of the ratelimit.v1.RatelimitService service. +type RatelimitServiceHandler interface { + // Replay synchronizes rate limit state between nodes using consistent hashing. + // + // Key behaviors: + // - Each identifier maps to exactly one origin server via consistent hashing + // - Edge nodes replay their local rate limit decisions to the origin + // - Origin maintains the source of truth for rate limit state + // - Edge nodes must update their state based on origin responses + // + // Flow: + // 1. Edge node receives rate limit request + // 2. Edge makes local decision (may be defensive) + // 3. Edge replays decision to origin + // 4. Origin processes and returns authoritative state + // 5. Edge updates local state and returns result to client + // + // This approach ensures eventual consistency while allowing for + // fast local decisions at the edge. + Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) +} + +// NewRatelimitServiceHandler builds an HTTP handler from the service implementation. It returns the +// path on which to mount the handler and the handler itself. +// +// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf +// and JSON codecs. They also support gzip compression. +func NewRatelimitServiceHandler(svc RatelimitServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + ratelimitServiceReplayHandler := connect.NewUnaryHandler( + RatelimitServiceReplayProcedure, + svc.Replay, + connect.WithSchema(ratelimitServiceReplayMethodDescriptor), + connect.WithHandlerOptions(opts...), + ) + return "/ratelimit.v1.RatelimitService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case RatelimitServiceReplayProcedure: + ratelimitServiceReplayHandler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// UnimplementedRatelimitServiceHandler returns CodeUnimplemented from all methods. +type UnimplementedRatelimitServiceHandler struct{} + +func (UnimplementedRatelimitServiceHandler) Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("ratelimit.v1.RatelimitService.Replay is not implemented")) +} diff --git a/go/gen/proto/ratelimit/v1/service.pb.go b/go/gen/proto/ratelimit/v1/service.pb.go new file mode 100644 index 0000000000..d4ceb641c6 --- /dev/null +++ b/go/gen/proto/ratelimit/v1/service.pb.go @@ -0,0 +1,548 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc (unknown) +// source: proto/ratelimit/v1/service.proto + +package ratelimitv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// RatelimitRequest represents a request to check or consume rate limit tokens. +// This is typically the first point of contact when a client wants to verify +// if they are allowed to perform an action under the rate limit constraints. +type RatelimitRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Unique identifier for the rate limit subject. + // This could be: + // - A user ID + // - An API key + // - An IP address + // - Any other unique identifier that needs rate limiting + Identifier string `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"` + // Maximum number of tokens allowed within the duration. + // Once this limit is reached, subsequent requests will be denied + // until there is more capacity. + Limit int64 `protobuf:"varint,2,opt,name=limit,proto3" json:"limit,omitempty"` + // Duration of the rate limit window in milliseconds. + // After this duration, a new window begins. + // Common values might be: + // - 1000 (1 second) + // - 60000 (1 minute) + // - 3600000 (1 hour) + Duration int64 `protobuf:"varint,3,opt,name=duration,proto3" json:"duration,omitempty"` + // Number of tokens to consume in this request. + // Defaults to 1 if not specified. + // Higher values can be used for operations that should count more heavily + // against the rate limit (e.g., batch operations). + Cost *int64 `protobuf:"varint,4,opt,name=cost,proto3,oneof" json:"cost,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RatelimitRequest) Reset() { + *x = RatelimitRequest{} + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RatelimitRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RatelimitRequest) ProtoMessage() {} + +func (x *RatelimitRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RatelimitRequest.ProtoReflect.Descriptor instead. +func (*RatelimitRequest) Descriptor() ([]byte, []int) { + return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{0} +} + +func (x *RatelimitRequest) GetIdentifier() string { + if x != nil { + return x.Identifier + } + return "" +} + +func (x *RatelimitRequest) GetLimit() int64 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *RatelimitRequest) GetDuration() int64 { + if x != nil { + return x.Duration + } + return 0 +} + +func (x *RatelimitRequest) GetCost() int64 { + if x != nil && x.Cost != nil { + return *x.Cost + } + return 0 +} + +// RatelimitResponse contains the result of a rate limit check. +// This response includes all necessary information for clients to understand +// their current rate limit status and when they can retry if limited. +type RatelimitResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Total limit configured for this window. + // This matches the limit specified in the request and is included + // for convenience in client implementations. + Limit int64 `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"` + // Number of tokens remaining in the current window. + // Clients can use this to implement progressive backoff or + // warn users when they're close to their limit. + Remaining int64 `protobuf:"varint,2,opt,name=remaining,proto3" json:"remaining,omitempty"` + // Unix timestamp (in milliseconds) when the current window expires. + // Clients can use this to: + // - Display time until reset to users + // - Implement automatic retry after window reset + // - Schedule future requests optimally + Reset_ int64 `protobuf:"varint,3,opt,name=reset,proto3" json:"reset,omitempty"` + // Whether the rate limit check was successful. + // true = request is allowed + // false = request is denied due to rate limit exceeded + Success bool `protobuf:"varint,4,opt,name=success,proto3" json:"success,omitempty"` + // Current token count in this window. + // This represents how many tokens have been consumed so far, + // useful for monitoring and debugging purposes. + Current int64 `protobuf:"varint,5,opt,name=current,proto3" json:"current,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RatelimitResponse) Reset() { + *x = RatelimitResponse{} + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RatelimitResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RatelimitResponse) ProtoMessage() {} + +func (x *RatelimitResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RatelimitResponse.ProtoReflect.Descriptor instead. +func (*RatelimitResponse) Descriptor() ([]byte, []int) { + return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{1} +} + +func (x *RatelimitResponse) GetLimit() int64 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *RatelimitResponse) GetRemaining() int64 { + if x != nil { + return x.Remaining + } + return 0 +} + +func (x *RatelimitResponse) GetReset_() int64 { + if x != nil { + return x.Reset_ + } + return 0 +} + +func (x *RatelimitResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *RatelimitResponse) GetCurrent() int64 { + if x != nil { + return x.Current + } + return 0 +} + +// Window represents a rate limiting time window with its state. +// The system uses a sliding window approach to provide smooth +// rate limiting behavior across window boundaries. +type Window struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Monotonically increasing sequence number for window ordering. + // The sequence is calculated as follows: + // sequence = time.Now().UnixMilli() % duration + Sequence int64 `protobuf:"varint,1,opt,name=sequence,proto3" json:"sequence,omitempty"` + // Duration of the window in milliseconds. + // This matches the duration from the original request and defines + // how long this window remains active. + Duration int64 `protobuf:"varint,2,opt,name=duration,proto3" json:"duration,omitempty"` + // Current token count in this window. + // This is the actual count of tokens consumed during this window's + // lifetime. It must never exceed the configured limit. + Counter int64 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + // Start time of the window (Unix timestamp in milliseconds). + // Used to: + // - Calculate window expiration + // - Determine if a window is still active + // - Handle sliding window calculations between current and previous windows + Start int64 `protobuf:"varint,4,opt,name=start,proto3" json:"start,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Window) Reset() { + *x = Window{} + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Window) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Window) ProtoMessage() {} + +func (x *Window) ProtoReflect() protoreflect.Message { + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Window.ProtoReflect.Descriptor instead. +func (*Window) Descriptor() ([]byte, []int) { + return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{2} +} + +func (x *Window) GetSequence() int64 { + if x != nil { + return x.Sequence + } + return 0 +} + +func (x *Window) GetDuration() int64 { + if x != nil { + return x.Duration + } + return 0 +} + +func (x *Window) GetCounter() int64 { + if x != nil { + return x.Counter + } + return 0 +} + +func (x *Window) GetStart() int64 { + if x != nil { + return x.Start + } + return 0 +} + +// ReplayRequest is used to synchronize rate limit state between nodes. +// This is a crucial part of maintaining consistency in a distributed +// rate limiting system. +type ReplayRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Original rate limit request that triggered the replay. + // Contains all the parameters needed to evaluate the rate limit + // on the origin server. + Request *RatelimitRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"` + // Indicates if the edge node denied the request. + // When false: The origin must increment the counter regardless of its own evaluation + // When true: The origin can evaluate the request fresh + // This field is crucial for maintaining consistency when edge nodes + // make defensive denials due to network issues or uncertainty. + Denied bool `protobuf:"varint,2,opt,name=denied,proto3" json:"denied,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReplayRequest) Reset() { + *x = ReplayRequest{} + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReplayRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReplayRequest) ProtoMessage() {} + +func (x *ReplayRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReplayRequest.ProtoReflect.Descriptor instead. +func (*ReplayRequest) Descriptor() ([]byte, []int) { + return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{3} +} + +func (x *ReplayRequest) GetRequest() *RatelimitRequest { + if x != nil { + return x.Request + } + return nil +} + +func (x *ReplayRequest) GetDenied() bool { + if x != nil { + return x.Denied + } + return false +} + +// ReplayResponse contains the synchronized rate limit state that +// should be used to update both the origin and edge nodes. +type ReplayResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Current active window state. + // This represents the authoritative state of the current window + // as determined by the origin server. + Current *Window `protobuf:"bytes,1,opt,name=current,proto3" json:"current,omitempty"` + // Previous window state for sliding window calculations. + // Used to smooth out rate limiting across window boundaries and + // prevent sharp cliffs in availability during window transitions. + Previous *Window `protobuf:"bytes,2,opt,name=previous,proto3" json:"previous,omitempty"` + // Rate limit response that should be used by the edge node. + // This is the authoritative response that should be returned to + // the client and used to update edge state. + Response *RatelimitResponse `protobuf:"bytes,3,opt,name=response,proto3" json:"response,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReplayResponse) Reset() { + *x = ReplayResponse{} + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReplayResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReplayResponse) ProtoMessage() {} + +func (x *ReplayResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_ratelimit_v1_service_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReplayResponse.ProtoReflect.Descriptor instead. +func (*ReplayResponse) Descriptor() ([]byte, []int) { + return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{4} +} + +func (x *ReplayResponse) GetCurrent() *Window { + if x != nil { + return x.Current + } + return nil +} + +func (x *ReplayResponse) GetPrevious() *Window { + if x != nil { + return x.Previous + } + return nil +} + +func (x *ReplayResponse) GetResponse() *RatelimitResponse { + if x != nil { + return x.Response + } + return nil +} + +var File_proto_ratelimit_v1_service_proto protoreflect.FileDescriptor + +var file_proto_ratelimit_v1_service_proto_rawDesc = string([]byte{ + 0x0a, 0x20, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, + 0x74, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x0c, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2e, 0x76, 0x31, + 0x22, 0x86, 0x01, 0x0a, 0x10, 0x52, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, + 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, + 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x64, + 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, + 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x17, 0x0a, 0x04, 0x63, 0x6f, 0x73, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x48, 0x00, 0x52, 0x04, 0x63, 0x6f, 0x73, 0x74, 0x88, 0x01, 0x01, + 0x42, 0x07, 0x0a, 0x05, 0x5f, 0x63, 0x6f, 0x73, 0x74, 0x22, 0x91, 0x01, 0x0a, 0x11, 0x52, 0x61, + 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, + 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65, 0x6d, 0x61, 0x69, 0x6e, 0x69, + 0x6e, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x61, 0x69, 0x6e, + 0x69, 0x6e, 0x67, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x65, 0x73, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x05, 0x72, 0x65, 0x73, 0x65, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, + 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x22, 0x70, 0x0a, + 0x06, 0x57, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, + 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, + 0x6e, 0x63, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x07, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x22, + 0x61, 0x0a, 0x0d, 0x52, 0x65, 0x70, 0x6c, 0x61, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x38, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1e, 0x2e, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2e, 0x76, 0x31, + 0x2e, 0x52, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x52, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x65, + 0x6e, 0x69, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x65, 0x6e, 0x69, + 0x65, 0x64, 0x22, 0xaf, 0x01, 0x0a, 0x0e, 0x52, 0x65, 0x70, 0x6c, 0x61, 0x79, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2e, 0x0a, 0x07, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, + 0x69, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x52, 0x07, 0x63, 0x75, + 0x72, 0x72, 0x65, 0x6e, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x72, 0x65, 0x76, 0x69, 0x6f, 0x75, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, + 0x6d, 0x69, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x52, 0x08, 0x70, + 0x72, 0x65, 0x76, 0x69, 0x6f, 0x75, 0x73, 0x12, 0x3b, 0x0a, 0x08, 0x72, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x72, 0x61, 0x74, 0x65, + 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, + 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x59, 0x0a, 0x10, 0x52, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, + 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x06, 0x52, 0x65, 0x70, 0x6c, + 0x61, 0x79, 0x12, 0x1b, 0x2e, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2e, 0x76, + 0x31, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x61, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1c, 0x2e, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x52, + 0x65, 0x70, 0x6c, 0x61, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, + 0x48, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x75, 0x6e, + 0x6b, 0x65, 0x79, 0x65, 0x64, 0x2f, 0x75, 0x6e, 0x6b, 0x65, 0x79, 0x2f, 0x61, 0x70, 0x70, 0x73, + 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x72, 0x61, 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x2f, 0x76, 0x31, 0x3b, 0x72, 0x61, + 0x74, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, +}) + +var ( + file_proto_ratelimit_v1_service_proto_rawDescOnce sync.Once + file_proto_ratelimit_v1_service_proto_rawDescData []byte +) + +func file_proto_ratelimit_v1_service_proto_rawDescGZIP() []byte { + file_proto_ratelimit_v1_service_proto_rawDescOnce.Do(func() { + file_proto_ratelimit_v1_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proto_ratelimit_v1_service_proto_rawDesc), len(file_proto_ratelimit_v1_service_proto_rawDesc))) + }) + return file_proto_ratelimit_v1_service_proto_rawDescData +} + +var file_proto_ratelimit_v1_service_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_proto_ratelimit_v1_service_proto_goTypes = []any{ + (*RatelimitRequest)(nil), // 0: ratelimit.v1.RatelimitRequest + (*RatelimitResponse)(nil), // 1: ratelimit.v1.RatelimitResponse + (*Window)(nil), // 2: ratelimit.v1.Window + (*ReplayRequest)(nil), // 3: ratelimit.v1.ReplayRequest + (*ReplayResponse)(nil), // 4: ratelimit.v1.ReplayResponse +} +var file_proto_ratelimit_v1_service_proto_depIdxs = []int32{ + 0, // 0: ratelimit.v1.ReplayRequest.request:type_name -> ratelimit.v1.RatelimitRequest + 2, // 1: ratelimit.v1.ReplayResponse.current:type_name -> ratelimit.v1.Window + 2, // 2: ratelimit.v1.ReplayResponse.previous:type_name -> ratelimit.v1.Window + 1, // 3: ratelimit.v1.ReplayResponse.response:type_name -> ratelimit.v1.RatelimitResponse + 3, // 4: ratelimit.v1.RatelimitService.Replay:input_type -> ratelimit.v1.ReplayRequest + 4, // 5: ratelimit.v1.RatelimitService.Replay:output_type -> ratelimit.v1.ReplayResponse + 5, // [5:6] is the sub-list for method output_type + 4, // [4:5] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_proto_ratelimit_v1_service_proto_init() } +func file_proto_ratelimit_v1_service_proto_init() { + if File_proto_ratelimit_v1_service_proto != nil { + return + } + file_proto_ratelimit_v1_service_proto_msgTypes[0].OneofWrappers = []any{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_ratelimit_v1_service_proto_rawDesc), len(file_proto_ratelimit_v1_service_proto_rawDesc)), + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_proto_ratelimit_v1_service_proto_goTypes, + DependencyIndexes: file_proto_ratelimit_v1_service_proto_depIdxs, + MessageInfos: file_proto_ratelimit_v1_service_proto_msgTypes, + }.Build() + File_proto_ratelimit_v1_service_proto = out.File + file_proto_ratelimit_v1_service_proto_goTypes = nil + file_proto_ratelimit_v1_service_proto_depIdxs = nil +} diff --git a/go/internal/services/ratelimit/bucket.go b/go/internal/services/ratelimit/bucket.go new file mode 100644 index 0000000000..65dc62fbbf --- /dev/null +++ b/go/internal/services/ratelimit/bucket.go @@ -0,0 +1,85 @@ +package ratelimit + +import ( + "fmt" + "sync" + "time" + + ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" +) + +// Generally there is one bucket per identifier. +// However if the same identifier is used with different config, such as limit +// or duration, there will be multiple buckets for the same identifier. +// +// A bucket is always uniquely identified by this triplet: identifier, limit, duration. +// See `bucketKey` for more details. +// +// A bucket reaches its lifetime when the last window has expired at least 1 * duration ago. +// In other words, we can remove a bucket when it is no longer relevant for +// ratelimit decisions. +type bucket struct { + sync.RWMutex + limit int64 + duration time.Duration + // sequence -> window + windows map[int64]*ratelimitv1.Window +} + +// bucketKey returns a unique key for an identifier and duration config +// the duration is required to ensure a change in ratelimit config will not +// reuse the same bucket and mess up the sequence numbers +type bucketKey struct { + identifier string + limit int64 + duration time.Duration +} + +func (b bucketKey) toString() string { + return fmt.Sprintf("%s-%d-%d", b.identifier, b.limit, b.duration.Milliseconds()) +} + +// getBucket returns a bucket for the given key and will create one if it does not exist. +// It returns the bucket and a boolean indicating if the bucket existed before. +func (s *service) getBucket(key bucketKey) (*bucket, bool) { + s.bucketsLock.RLock() + b, ok := s.buckets[key.toString()] + s.bucketsLock.RUnlock() + if !ok { + b = &bucket{ + limit: key.limit, + duration: key.duration, + windows: make(map[int64]*ratelimitv1.Window), + } + s.bucketsLock.Lock() + s.buckets[key.toString()] = b + s.bucketsLock.Unlock() + } + return b, ok +} + +// must be called while holding a lock on the bucket +func (b *bucket) getCurrentWindow(now time.Time) *ratelimitv1.Window { + sequence := calculateSequence(now, b.duration) + + w, ok := b.windows[sequence] + if !ok { + w = newWindow(sequence, now.Truncate(b.duration), b.duration) + b.windows[sequence] = w + } + + return w +} + +// must be called while holding a lock on the bucket +func (b *bucket) getPreviousWindow(now time.Time) *ratelimitv1.Window { + sequence := calculateSequence(now, b.duration) - 1 + + w, ok := b.windows[sequence] + if !ok { + w = newWindow(sequence, now.Add(-b.duration).Truncate(b.duration), b.duration) + b.windows[sequence] = w + } + + return w +} diff --git a/go/internal/services/ratelimit/interface.go b/go/internal/services/ratelimit/interface.go new file mode 100644 index 0000000000..de94c8d328 --- /dev/null +++ b/go/internal/services/ratelimit/interface.go @@ -0,0 +1,75 @@ +package ratelimit + +import ( + "context" + "time" +) + +type Service interface { + Ratelimit(context.Context, RatelimitRequest) (RatelimitResponse, error) +} + +// RatelimitRequest represents a request to check or consume rate limit tokens. +// This is typically the first point of contact when a client wants to verify +// if they are allowed to perform an action under the rate limit constraints. +type RatelimitRequest struct { + // Unique identifier for the rate limit subject. + // This could be: + // - A user ID + // - An API key + // - An IP address + // - Any other unique identifier that needs rate limiting + Identifier string + + // Maximum number of tokens allowed within the duration. + // Once this limit is reached, subsequent requests will be denied + // until there is more capacity. + Limit int64 + + // Duration of the rate limit window in milliseconds. + // After this duration, a new window begins. + // Common values might be: + // - 1000 (1 second) + // - 60000 (1 minute) + // - 3600000 (1 hour) + Duration time.Duration + + // Number of tokens to consume in this request. + // Defaults to 1 if not specified. + // Higher values can be used for operations that should count more heavily + // against the rate limit (e.g., batch operations). + Cost int64 +} + +// RatelimitResponse contains the result of a rate limit check. +// This response includes all necessary information for clients to understand +// their current rate limit status and when they can retry if limited. +type RatelimitResponse struct { + // Total limit configured for this window. + // This matches the limit specified in the request and is included + // for convenience in client implementations. + Limit int64 + + // Number of tokens remaining in the current window. + // Clients can use this to implement progressive backoff or + // warn users when they're close to their limit. + Remaining int64 + + // Unix timestamp (in milliseconds) when the current window expires. + // Clients can use this to: + // - Display time until reset to users + // - Implement automatic retry after window reset + // - Schedule future requests optimally + Reset int64 + + // Whether the rate limit check was successful. + // true = request is allowed + // false = request is denied due to rate limit exceeded + Success bool + + // Current token count in this window. + // This represents how many tokens have been consumed so far, + // useful for monitoring and debugging purposes. + Current int64 +} +type Middleware func(Service) Service diff --git a/go/internal/services/ratelimit/peer.go b/go/internal/services/ratelimit/peer.go new file mode 100644 index 0000000000..8bfb480363 --- /dev/null +++ b/go/internal/services/ratelimit/peer.go @@ -0,0 +1,42 @@ +package ratelimit + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "strings" + + "connectrpc.com/connect" + "connectrpc.com/otelconnect" + "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1/ratelimitv1connect" + "github.com/unkeyed/unkey/go/pkg/tracing" +) + +func (s *service) getPeer(peerID string) (ratelimitv1connect.RatelimitServiceClient, error) { + s.peerMu.RLock() + defer s.peerMu.RUnlock() + + peer, ok := s.peers[peerID] + if !ok { + return nil, fmt.Errorf("peer not found") + } + return peer, nil + +} + +func (s *service) newPeer(id string, rpcAddr string) (ratelimitv1connect.RatelimitServiceClient, error) { + + if !strings.Contains(rpcAddr, "://") { + rpcAddr = "http://" + rpcAddr + } + + interceptor, err := otelconnect.NewInterceptor(otelconnect.WithTracerProvider(tracing.GetGlobalTraceProvider())) + if err != nil { + s.logger.Error(context.Background(), "failed to create interceptor", slog.String("error", err.Error())) + return nil, err + } + + c := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr, connect.WithInterceptors(interceptor)) + return c, nil +} diff --git a/go/internal/services/ratelimit/replay.go b/go/internal/services/ratelimit/replay.go new file mode 100644 index 0000000000..aff2429c84 --- /dev/null +++ b/go/internal/services/ratelimit/replay.go @@ -0,0 +1,86 @@ +package ratelimit + +import ( + "context" + "time" + + "connectrpc.com/connect" + "github.com/unkeyed/unkey/apps/agent/pkg/prometheus" + ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" + "github.com/unkeyed/unkey/go/pkg/tracing" +) + +func (s *service) replay(req RatelimitRequest) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ctx, span := tracing.Start(ctx, "ratelimit.replay") + defer span.End() + + now := s.clock.Now() + + key := bucketKey{req.Identifier, req.Limit, req.Duration}.toString() + client, peer, err := s.getPeerClient(ctx, key) + if err != nil { + tracing.RecordError(span, err) + s.logger.Warn().Err(err).Str("key", key).Msg("unable to create peer client") + return + } + if peer.Id == s.cluster.NodeId() { + return + } + + res, err := s.syncCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.PushPullResponse], error) { + innerCtx, cancel = context.WithTimeout(innerCtx, 10*time.Second) + defer cancel() + return client.PushPull(innerCtx, connect.NewRequest(req.req)) + }) + if err != nil { + s.peersMu.Lock() + s.logger.Warn().Str("peerId", peer.Id).Err(err).Msg("resetting peer client due to error") + delete(s.peers, peer.Id) + s.peersMu.Unlock() + tracing.RecordError(span, err) + s.logger.Warn().Err(err).Str("peerId", peer.Id).Str("addr", peer.RpcAddr).Msg("failed to push pull") + return + } + + err = s.SetCounter(ctx, + setCounterRequest{ + Identifier: req.req.Request.Identifier, + Limit: req.req.Request.Limit, + Counter: res.Msg.Current.Counter, + Sequence: res.Msg.Current.Sequence, + Duration: duration, + Time: t, + }, + setCounterRequest{ + Identifier: req.req.Request.Identifier, + + Counter: res.Msg.Previous.Counter, + Sequence: res.Msg.Previous.Sequence, + Duration: duration, + Time: t, + }, + ) + + if req.localPassed == res.Msg.Response.Success { + ratelimitAccuracy.WithLabelValues("true").Inc() + } else { + ratelimitAccuracy.WithLabelValues("false").Inc() + } + + // req.events is guaranteed to have at least element + // and the first one should be the oldest event, so we can use it to get the max latency + latency := time.Since(t) + labels := map[string]string{ + "nodeId": s.cluster.NodeId(), + "peerId": peer.Id, + } + prometheus.RatelimitPushPullEvents.With(labels).Inc() + + prometheus.RatelimitPushPullLatency.With(labels).Observe(latency.Seconds()) + + // if we got this far, we pushpulled successfully with a peer and don't need to try the rest + +} diff --git a/go/internal/services/ratelimit/sequence.go b/go/internal/services/ratelimit/sequence.go new file mode 100644 index 0000000000..10b07421d9 --- /dev/null +++ b/go/internal/services/ratelimit/sequence.go @@ -0,0 +1,7 @@ +package ratelimit + +import "time" + +func calculateSequence(t time.Time, duration time.Duration) int64 { + return t.UnixMilli() / duration.Milliseconds() +} diff --git a/go/internal/services/ratelimit/sliding_window.go b/go/internal/services/ratelimit/sliding_window.go new file mode 100644 index 0000000000..ce8ebb257c --- /dev/null +++ b/go/internal/services/ratelimit/sliding_window.go @@ -0,0 +1,94 @@ +package ratelimit + +import ( + "context" + "math" + "sync" + + "connectrpc.com/connect" + ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" + "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1/ratelimitv1connect" + "github.com/unkeyed/unkey/go/pkg/circuitbreaker" + "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/cluster" + "github.com/unkeyed/unkey/go/pkg/logging" + "github.com/unkeyed/unkey/go/pkg/tracing" + "go.opentelemetry.io/otel/attribute" +) + +type service struct { + clock clock.Clock + + logger logging.Logger + cluster cluster.Cluster + + shutdownCh chan struct{} + + bucketsLock sync.RWMutex + // identifier+sequence -> bucket + buckets map[string]*bucket + + peerMu sync.RWMutex + peers map[string]ratelimitv1connect.RatelimitServiceClient + + replayCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.ReplayResponse]] +} + +func (r *service) Ratelimit(ctx context.Context, req RatelimitRequest) RatelimitResponse { + ctx, span := tracing.Start(ctx, "slidingWindow.Ratelimit") + defer span.End() + + now := r.clock.Now() + + key := bucketKey{req.Identifier, req.Limit, req.Duration} + span.SetAttributes(attribute.String("key", string(key.toString()))) + + bucket, _ := r.getBucket(key) + + bucket.Lock() + defer bucket.Unlock() + + currentWindow := bucket.getCurrentWindow(now) + previousWindow := bucket.getPreviousWindow(now) + currentWindowPercentage := float64(now.UnixMilli()-currentWindow.Start) / float64(req.Duration) + previousWindowPercentage := 1.0 - currentWindowPercentage + + // Calculate the current count including all leases + fromPreviousWindow := float64(previousWindow.Counter) * previousWindowPercentage + fromCurrentWindow := float64(currentWindow.Counter) + + current := int64(math.Ceil(fromCurrentWindow + fromPreviousWindow)) + + // Evaluate if the request should pass or not + + if current+req.Cost > req.Limit { + + remaining := req.Limit - current + if remaining < 0 { + remaining = 0 + } + return RatelimitResponse{ + Success: false, + Remaining: remaining, + Reset: currentWindow.Start + currentWindow.Duration, + Limit: req.Limit, + Current: current, + } + } + + currentWindow.Counter += req.Cost + + current += req.Cost + + remaining := req.Limit - current + if remaining < 0 { + remaining = 0 + } + return RatelimitResponse{ + Success: true, + Remaining: remaining, + Reset: currentWindow.Start + currentWindow.Duration, + Limit: req.Limit, + Current: current, + } +} diff --git a/go/internal/services/ratelimit/window.go b/go/internal/services/ratelimit/window.go new file mode 100644 index 0000000000..3957620107 --- /dev/null +++ b/go/internal/services/ratelimit/window.go @@ -0,0 +1,16 @@ +package ratelimit + +import ( + "time" + + ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" +) + +func newWindow(sequence int64, t time.Time, duration time.Duration) *ratelimitv1.Window { + return &ratelimitv1.Window{ + Sequence: sequence, + Start: t.Truncate(duration).UnixMilli(), + Duration: duration.Milliseconds(), + Counter: 0, + } +} diff --git a/go/pkg/certificate/amazon_certificate_manager.go b/go/pkg/certificate/amazon_certificate_manager.go new file mode 100644 index 0000000000..e683327fc8 --- /dev/null +++ b/go/pkg/certificate/amazon_certificate_manager.go @@ -0,0 +1,4 @@ +package certificate + +type ACM struct { +} diff --git a/go/pkg/certificate/autocert.go b/go/pkg/certificate/autocert.go new file mode 100644 index 0000000000..9535cff619 --- /dev/null +++ b/go/pkg/certificate/autocert.go @@ -0,0 +1,34 @@ +package certificate + +import ( + "crypto/tls" + "log" + "sync/atomic" + + "golang.org/x/crypto/acme/autocert" +) + +type devCertificateSource struct { + manager *autocert.Manager + counter atomic.Uint64 +} + +func (cs *devCertificateSource) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.Println("getCertificate", cs.counter.Add(1)) + + return cs.manager.GetCertificate(info) +} + +func NewDevCertificateSource() (*devCertificateSource, error) { + + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + // HostPolicy: autocert.HostWhitelist("andreas.localhost.com"), + } + + return &devCertificateSource{ + manager: m, + counter: atomic.Uint64{}, + }, nil + +} diff --git a/go/pkg/certificate/interface.go b/go/pkg/certificate/interface.go new file mode 100644 index 0000000000..20eb22f598 --- /dev/null +++ b/go/pkg/certificate/interface.go @@ -0,0 +1,9 @@ +package certificate + +import ( + "crypto/tls" +) + +type Source interface { + GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) +} diff --git a/go/pkg/circuitbreaker/interface.go b/go/pkg/circuitbreaker/interface.go new file mode 100644 index 0000000000..bc1332ae59 --- /dev/null +++ b/go/pkg/circuitbreaker/interface.go @@ -0,0 +1,29 @@ +package circuitbreaker + +import ( + "context" + "errors" +) + +type State string + +var ( + // Open state means the circuit breaker is open and requests are not allowed + // to pass through + Open State = "open" + // HalfOpen state means the circuit breaker is in a state of testing the + // upstream service to see if it has recovered + HalfOpen State = "halfopen" + // Closed state means the circuit breaker is allowing requests to pass + // through to the upstream service + Closed State = "closed" +) + +var ( + ErrTripped = errors.New("circuit breaker is open") + ErrTooManyRequests = errors.New("too many requests during half open state") +) + +type CircuitBreaker[Res any] interface { + Do(ctx context.Context, fn func(context.Context) (Res, error)) (Res, error) +} diff --git a/go/pkg/circuitbreaker/lib.go b/go/pkg/circuitbreaker/lib.go new file mode 100644 index 0000000000..327ebb2b09 --- /dev/null +++ b/go/pkg/circuitbreaker/lib.go @@ -0,0 +1,227 @@ +package circuitbreaker + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/unkeyed/unkey/apps/agent/pkg/clock" + "github.com/unkeyed/unkey/apps/agent/pkg/logging" + "github.com/unkeyed/unkey/apps/agent/pkg/tracing" +) + +type CB[Res any] struct { + sync.Mutex + // This is a pointer to the configuration of the circuit breaker because we + // need to modify the clock for testing + config *config + + logger logging.Logger + + // State of the circuit + state State + + // reset the counters every cyclic period + resetCountersAt time.Time + + // reset the state every recoveryTimeout + resetStateAt time.Time + + // counters are protected by the mutex and are reset every cyclic period + requests int + successes int + failures int + consecutiveSuccesses int + consecutiveFailures int +} + +type config struct { + name string + // Max requests that may pass through the circuit breaker in its half-open state + // If all requests are successful, the circuit will close + // If any request fails, the circuit will remaing half open until the next cycle + maxRequests int + + // Interval to clear counts while the circuit is closed + cyclicPeriod time.Duration + + // How long the circuit will stay open before transitioning to half-open + timeout time.Duration + + // Determine whether the error is a downstream error or not + // If the error is a downstream error, the circuit will count it + // If the error is not a downstream error, the circuit will not count it + isDownstreamError func(error) bool + + // How many downstream errors within a cyclic period are allowed before the + // circuit trips and opens + tripThreshold int + + // Clock to use for timing, defaults to the system clock but can be overridden for testing + clock clock.Clock + + logger logging.Logger +} + +func WithMaxRequests(maxRequests int) applyConfig { + return func(c *config) { + c.maxRequests = maxRequests + } +} + +func WithCyclicPeriod(cyclicPeriod time.Duration) applyConfig { + return func(c *config) { + c.cyclicPeriod = cyclicPeriod + } +} +func WithIsDownstreamError(isDownstreamError func(error) bool) applyConfig { + return func(c *config) { + c.isDownstreamError = isDownstreamError + } +} +func WithTripThreshold(tripThreshold int) applyConfig { + return func(c *config) { + c.tripThreshold = tripThreshold + } +} + +func WithTimeout(timeout time.Duration) applyConfig { + return func(c *config) { + c.timeout = timeout + } +} + +// for testing +func WithClock(clock clock.Clock) applyConfig { + return func(c *config) { + c.clock = clock + } +} + +func WithLogger(logger logging.Logger) applyConfig { + return func(c *config) { + c.logger = logger + } +} + +// applyConfig applies a config setting to the circuit breaker +type applyConfig func(*config) + +func New[Res any](name string, applyConfigs ...applyConfig) *CB[Res] { + + cfg := &config{ + name: name, + maxRequests: 10, + cyclicPeriod: 5 * time.Second, + timeout: time.Minute, + isDownstreamError: func(err error) bool { + return err != nil + }, + tripThreshold: 5, + clock: clock.New(), + logger: logging.New(nil), + } + + for _, apply := range applyConfigs { + apply(cfg) + } + + cb := &CB[Res]{ + config: cfg, + logger: cfg.logger, + state: Closed, + resetCountersAt: cfg.clock.Now().Add(cfg.cyclicPeriod), + resetStateAt: cfg.clock.Now().Add(cfg.timeout), + } + + return cb +} + +var _ CircuitBreaker[any] = &CB[any]{} + +func (cb *CB[Res]) Do(ctx context.Context, fn func(context.Context) (Res, error)) (res Res, err error) { + ctx, span := tracing.Start(ctx, tracing.NewSpanName(fmt.Sprintf("circuitbreaker.%s", cb.config.name), "Do")) + defer span.End() + + err = cb.preflight(ctx) + if err != nil { + return res, err + } + + ctx, fnSpan := tracing.Start(ctx, tracing.NewSpanName(fmt.Sprintf("circuitbreaker.%s", cb.config.name), "fn")) + res, err = fn(ctx) + fnSpan.End() + + cb.postflight(ctx, err) + + return res, err + +} + +// preflight checks if the circuit is ready to accept a request +func (cb *CB[Res]) preflight(ctx context.Context) error { + ctx, span := tracing.Start(ctx, tracing.NewSpanName(fmt.Sprintf("circuitbreaker.%s", cb.config.name), "preflight")) + defer span.End() + cb.Lock() + defer cb.Unlock() + + now := cb.config.clock.Now() + + if now.After(cb.resetCountersAt) { + cb.requests = 0 + cb.successes = 0 + cb.failures = 0 + cb.consecutiveSuccesses = 0 + cb.consecutiveFailures = 0 + cb.resetCountersAt = now.Add(cb.config.cyclicPeriod) + } + if cb.state == Open && now.After(cb.resetStateAt) { + cb.state = HalfOpen + cb.resetStateAt = now.Add(cb.config.timeout) + } + + requests.WithLabelValues(cb.config.name, string(cb.state)).Inc() + + if cb.state == Open { + return ErrTripped + } + + cb.logger.Debug().Str("state", string(cb.state)).Int("requests", cb.requests).Int("maxRequests", cb.config.maxRequests).Msg("circuit breaker state") + if cb.state == HalfOpen && cb.requests >= cb.config.maxRequests { + return ErrTooManyRequests + } + return nil +} + +// postflight updates the circuit breaker state based on the result of the request +func (cb *CB[Res]) postflight(ctx context.Context, err error) { + ctx, span := tracing.Start(ctx, tracing.NewSpanName(fmt.Sprintf("circuitbreaker.%s", cb.config.name), "postflight")) + defer span.End() + cb.Lock() + defer cb.Unlock() + cb.requests++ + if cb.config.isDownstreamError(err) { + cb.failures++ + cb.consecutiveFailures++ + cb.consecutiveSuccesses = 0 + } else { + cb.successes++ + cb.consecutiveSuccesses++ + cb.consecutiveFailures = 0 + } + + switch cb.state { + + case Closed: + if cb.failures >= cb.config.tripThreshold { + cb.state = Open + } + + case HalfOpen: + if cb.consecutiveSuccesses >= cb.config.maxRequests { + cb.state = Closed + } + } + +} diff --git a/go/pkg/circuitbreaker/lib_test.go b/go/pkg/circuitbreaker/lib_test.go new file mode 100644 index 0000000000..b4ea0f5260 --- /dev/null +++ b/go/pkg/circuitbreaker/lib_test.go @@ -0,0 +1,93 @@ +package circuitbreaker + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/apps/agent/pkg/clock" + "github.com/unkeyed/unkey/apps/agent/pkg/logging" +) + +var errTestDownstream = errors.New("downstream test error") + +func TestCircuitBreakerStates(t *testing.T) { + + c := clock.NewTestClock() + cb := New[int]("test", WithCyclicPeriod(5*time.Second), WithClock(c), WithTripThreshold(3), WithLogger(logging.New(nil))) + + // Test Closed State + for i := 0; i < 3; i++ { + _, err := cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 0, errTestDownstream + }) + require.ErrorIs(t, err, errTestDownstream) + } + require.Equal(t, Open, cb.state) + + // Test Open State + _, err := cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 0, errTestDownstream + }) + require.ErrorIs(t, err, ErrTripped) + require.Equal(t, Open, cb.state) + + // Test Half-Open State + c.Tick(2 * time.Minute) // Advance time to reset + _, err = cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 42, nil + }) + require.NoError(t, err) + require.Equal(t, HalfOpen, cb.state) +} + +func TestCircuitBreakerReset(t *testing.T) { + + c := clock.NewTestClock() + cb := New[int]("test", WithCyclicPeriod(5*time.Second), WithClock(c), WithTripThreshold(3), WithTimeout(20*time.Second)) + + // Trigger circuit breaker to open + for i := 0; i < 3; i++ { + _, err := cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 0, errTestDownstream + }) + require.ErrorIs(t, err, errTestDownstream) + } + + require.Equal(t, Open, cb.state) + + // Advance time to reset + c.Tick(30 * time.Second) + + // Next request should be allowed (Half-Open state) + _, err := cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 42, nil + }) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + require.Equal(t, HalfOpen, cb.state) + +} + +func TestCircuitBreakerRecovers(t *testing.T) { + + cb := New[int]("test", WithMaxRequests(2)) + + // Reset to Half-Open state + cb.state = HalfOpen + + // Two requests should succeed + for i := 0; i < 2; i++ { + _, err := cb.Do(context.Background(), func(ctx context.Context) (int, error) { + return 42, nil + }) + require.NoError(t, err) + } + + // Circuit should close + require.Equal(t, Closed, cb.state) +} diff --git a/go/pkg/circuitbreaker/metrics.go b/go/pkg/circuitbreaker/metrics.go new file mode 100644 index 0000000000..ccc0167340 --- /dev/null +++ b/go/pkg/circuitbreaker/metrics.go @@ -0,0 +1,14 @@ +package circuitbreaker + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + requests = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "agent", + Subsystem: "circuitbreaker", + Name: "requests", + }, []string{"name", "state"}) +) diff --git a/go/pkg/membership/fake.go b/go/pkg/membership/fake.go new file mode 100644 index 0000000000..b9216af281 --- /dev/null +++ b/go/pkg/membership/fake.go @@ -0,0 +1,117 @@ +package membership + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/unkeyed/unkey/go/pkg/events" + "github.com/unkeyed/unkey/go/pkg/logging" +) + +type fakeMembership struct { + mu sync.Mutex + started bool + logger logging.Logger + rpcAddr string + joinEvents events.Topic[Member] + leaveEvents events.Topic[Member] + + members []Member + + nodeID string +} + +type FakeConfig struct { + NodeID string + RpcAddr string + Logger logging.Logger +} + +func NewFake(config FakeConfig) (*fakeMembership, error) { + + return &fakeMembership{ + mu: sync.Mutex{}, + logger: config.Logger.With(slog.String("pkg", "service discovery"), slog.String("type", "fake")), + rpcAddr: config.RpcAddr, + joinEvents: events.NewTopic[Member](), + leaveEvents: events.NewTopic[Member](), + nodeID: config.NodeID, + members: []Member{}, + }, nil + +} + +func (m *fakeMembership) AddMember(member Member) { + m.mu.Lock() + defer m.mu.Unlock() + + for _, existing := range m.members { + if existing.ID == member.ID { + return + } + } + m.members = append(m.members, member) + m.joinEvents.Emit(context.Background(), member) +} + +func (m *fakeMembership) RemoveMember(member Member) { + m.mu.Lock() + defer m.mu.Unlock() + + for i, existing := range m.members { + if existing.ID == member.ID { + m.members[i] = m.members[len(m.members)-1] + m.members = m.members[:len(m.members)-1] + m.leaveEvents.Emit(context.Background(), member) + + return + } + } +} + +func (m *fakeMembership) Join(ctx context.Context) (int, error) { + + m.mu.Lock() + defer m.mu.Unlock() + if m.started { + return 0, fmt.Errorf("Membership already started") + } + m.started = true + + members, err := m.Members(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get members: %w", err) + } + return len(members), nil + +} + +func (m *fakeMembership) heartbeatRedisKey() string { + return fmt.Sprintf("cluster::membership::nodes::%s", m.nodeID) +} + +func (m *fakeMembership) Leave(ctx context.Context) error { + return nil +} + +func (m *fakeMembership) Members(ctx context.Context) ([]Member, error) { + m.mu.Lock() + defer m.mu.Unlock() + + membersCopy := make([]Member, len(m.members)) + copy(membersCopy, m.members) + + return membersCopy, nil +} +func (m *fakeMembership) Addr() string { + return m.rpcAddr +} +func (m *fakeMembership) SubscribeJoinEvents() <-chan Member { + return m.joinEvents.Subscribe("cluster_join_events") +} + +func (m *fakeMembership) SubscribeLeaveEvents() <-chan Member { + return m.leaveEvents.Subscribe("cluster_leave_events") +} diff --git a/go/pkg/zen/middleware_auth.go b/go/pkg/zen/middleware_auth.go new file mode 100644 index 0000000000..da57190bb0 --- /dev/null +++ b/go/pkg/zen/middleware_auth.go @@ -0,0 +1,18 @@ +package zen + +import ( + "github.com/unkeyed/unkey/go/pkg/zen/validation" +) + +func WithValidation(validator *validation.Validator) Middleware { + return func(next HandleFunc) HandleFunc { + return func(s *Session) error { + err, valid := validator.Validate(s.r) + if !valid { + err.RequestId = s.requestID + return s.JSON(err.Status, err) + } + return next(s) + } + } +} diff --git a/go/proto/ratelimit/v1/service.proto b/go/proto/ratelimit/v1/service.proto new file mode 100644 index 0000000000..573c812c89 --- /dev/null +++ b/go/proto/ratelimit/v1/service.proto @@ -0,0 +1,157 @@ +syntax = "proto3"; + +package ratelimit.v1; + +option go_package = "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1;ratelimitv1"; + +// RatelimitRequest represents a request to check or consume rate limit tokens. +// This is typically the first point of contact when a client wants to verify +// if they are allowed to perform an action under the rate limit constraints. +message RatelimitRequest { + // Unique identifier for the rate limit subject. + // This could be: + // - A user ID + // - An API key + // - An IP address + // - Any other unique identifier that needs rate limiting + string identifier = 1; + + // Maximum number of tokens allowed within the duration. + // Once this limit is reached, subsequent requests will be denied + // until there is more capacity. + int64 limit = 2; + + // Duration of the rate limit window in milliseconds. + // After this duration, a new window begins. + // Common values might be: + // - 1000 (1 second) + // - 60000 (1 minute) + // - 3600000 (1 hour) + int64 duration = 3; + + // Number of tokens to consume in this request. + // Defaults to 1 if not specified. + // Higher values can be used for operations that should count more heavily + // against the rate limit (e.g., batch operations). + optional int64 cost = 4; +} + +// RatelimitResponse contains the result of a rate limit check. +// This response includes all necessary information for clients to understand +// their current rate limit status and when they can retry if limited. +message RatelimitResponse { + // Total limit configured for this window. + // This matches the limit specified in the request and is included + // for convenience in client implementations. + int64 limit = 1; + + // Number of tokens remaining in the current window. + // Clients can use this to implement progressive backoff or + // warn users when they're close to their limit. + int64 remaining = 2; + + // Unix timestamp (in milliseconds) when the current window expires. + // Clients can use this to: + // - Display time until reset to users + // - Implement automatic retry after window reset + // - Schedule future requests optimally + int64 reset = 3; + + // Whether the rate limit check was successful. + // true = request is allowed + // false = request is denied due to rate limit exceeded + bool success = 4; + + // Current token count in this window. + // This represents how many tokens have been consumed so far, + // useful for monitoring and debugging purposes. + int64 current = 5; +} + +// Window represents a rate limiting time window with its state. +// The system uses a sliding window approach to provide smooth +// rate limiting behavior across window boundaries. +message Window { + // Monotonically increasing sequence number for window ordering. + // The sequence is calculated as follows: + // sequence = time.Now().UnixMilli() / duration + int64 sequence = 1; + + // Duration of the window in milliseconds. + // This matches the duration from the original request and defines + // how long this window remains active. + int64 duration = 2; + + // Current token count in this window. + // This is the actual count of tokens consumed during this window's + // lifetime. It must never exceed the configured limit. + int64 counter = 3; + + // Start time of the window (Unix timestamp in milliseconds). + // Used to: + // - Calculate window expiration + // - Determine if a window is still active + // - Handle sliding window calculations between current and previous windows + int64 start = 4; +} + +// ReplayRequest is used to synchronize rate limit state between nodes. +// This is a crucial part of maintaining consistency in a distributed +// rate limiting system. +message ReplayRequest { + // Original rate limit request that triggered the replay. + // Contains all the parameters needed to evaluate the rate limit + // on the origin server. + RatelimitRequest request = 1; + + // Indicates if the edge node denied the request. + // When false: The origin must increment the counter regardless of its own evaluation + // When true: The origin can evaluate the request fresh + // This field is crucial for maintaining consistency when edge nodes + // make defensive denials due to network issues or uncertainty. + bool denied = 2; +} + +// ReplayResponse contains the synchronized rate limit state that +// should be used to update both the origin and edge nodes. +message ReplayResponse { + // Current active window state. + // This represents the authoritative state of the current window + // as determined by the origin server. + Window current = 1; + + // Previous window state for sliding window calculations. + // Used to smooth out rate limiting across window boundaries and + // prevent sharp cliffs in availability during window transitions. + Window previous = 2; + + // Rate limit response that should be used by the edge node. + // This is the authoritative response that should be returned to + // the client and used to update edge state. + RatelimitResponse response = 3; +} + +// RatelimitService provides rate limiting functionality in a distributed system. +// The service is designed to work in a multi-node environment where consistency +// and reliability are crucial. +service RatelimitService { + + // Replay synchronizes rate limit state between nodes using consistent hashing. + // + // Key behaviors: + // - Each identifier maps to exactly one origin server via consistent hashing + // - Edge nodes replay their local rate limit decisions to the origin + // - Origin maintains the source of truth for rate limit state + // - Edge nodes must update their state based on origin responses + // + // Flow: + // 1. Edge node receives rate limit request + // 2. Edge makes local decision (may be defensive) + // 3. Edge replays decision to origin + // 4. Origin processes and returns authoritative state + // 5. Edge updates local state and returns result to client + // + // This approach ensures eventual consistency while allowing for + // fast local decisions at the edge. + rpc Replay(ReplayRequest) returns (ReplayResponse) {} +} diff --git a/internal/db/src/schema/workspaces.ts b/internal/db/src/schema/workspaces.ts index 0498f1497f..58ee4816b7 100644 --- a/internal/db/src/schema/workspaces.ts +++ b/internal/db/src/schema/workspaces.ts @@ -58,7 +58,6 @@ export const workspaces = mysqlTable( */ rbac?: boolean; - ratelimit?: boolean; identities?: boolean; /** diff --git a/internal/db/src/types.ts b/internal/db/src/types.ts index ee288c50ba..be93c7a286 100644 --- a/internal/db/src/types.ts +++ b/internal/db/src/types.ts @@ -21,4 +21,5 @@ export type KeyPermission = InferSelectModel; export type Ratelimit = InferSelectModel; export type Identity = InferSelectModel; export type AuditLog = InferSelectModel; +export type AuditLogBucket = InferSelectModel; export type AuditLogTarget = InferSelectModel;