diff --git a/apps/api/src/pkg/keys/service.ts b/apps/api/src/pkg/keys/service.ts index 884ce8b4b4..21c5349b37 100644 --- a/apps/api/src/pkg/keys/service.ts +++ b/apps/api/src/pkg/keys/service.ts @@ -10,6 +10,18 @@ import type { PermissionQuery, RBAC } from "@unkey/rbac"; import type { Logger } from "@unkey/worker-logging"; import { retry } from "../util/retry"; +/* + * Unless specified by the user, we deduct this from the current `remaining` + * value of the key. + */ +const DEFAULT_REMAINING_COST = 1; + +/** + * Unless specified by the user, we deduct this from the current ratelimit + * tokens of the key. + */ +const DEFAULT_RATELIMIT_COST = 1; + export class DisabledWorkspaceError extends BaseError<{ workspaceId: string }> { public readonly retry = false; public readonly name = DisabledWorkspaceError.name; @@ -142,6 +154,7 @@ export class KeyService { permissionQuery?: PermissionQuery; ratelimit?: { cost?: number }; ratelimits?: Array>; + remaining?: { cost: number }; }, ): Promise< Result< @@ -322,6 +335,7 @@ export class KeyService { permissionQuery?: PermissionQuery; ratelimit?: { cost?: number }; ratelimits?: Array>; + remaining?: { cost: number }; }, opts?: { skipCache?: boolean; @@ -535,7 +549,7 @@ export class KeyService { ratelimits.default = { identity: data.key.id, name: data.ratelimits.default.name, - cost: req.ratelimit?.cost ?? 1, + cost: req.ratelimit?.cost ?? DEFAULT_RATELIMIT_COST, limit: data.ratelimits.default.limit, duration: data.ratelimits.default.duration, }; @@ -546,7 +560,7 @@ export class KeyService { ratelimits[r.name] = { identity: data.identity?.id ?? data.key.id, name: r.name, - cost: r.cost ?? 1, + cost: r.cost ?? DEFAULT_RATELIMIT_COST, limit: r.limit, duration: r.duration, }; @@ -558,7 +572,7 @@ export class KeyService { ratelimits[configured.name] = { identity: data.identity?.id ?? data.key.id, name: configured.name, - cost: r.cost ?? 1, + cost: r.cost ?? DEFAULT_RATELIMIT_COST, limit: configured.limit, duration: configured.duration, }; @@ -591,7 +605,10 @@ export class KeyService { let remaining: number | undefined = undefined; if (data.key.remaining !== null) { - const limited = await this.usageLimiter.limit({ keyId: data.key.id }); + const limited = await this.usageLimiter.limit({ + keyId: data.key.id, + cost: req.remaining?.cost ?? DEFAULT_REMAINING_COST, + }); remaining = limited.remaining; if (!limited.valid) { return Ok({ diff --git a/apps/api/src/pkg/usagelimit/durable_object.ts b/apps/api/src/pkg/usagelimit/durable_object.ts index 8e680b1065..dfadd0bc03 100644 --- a/apps/api/src/pkg/usagelimit/durable_object.ts +++ b/apps/api/src/pkg/usagelimit/durable_object.ts @@ -75,19 +75,19 @@ export class DurableObjectUsagelimiter implements DurableObject { }); } - if (this.key.remaining <= 0) { + if (this.key.remaining <= 0 && req.cost !== 0) { return Response.json({ valid: false, remaining: 0, }); } - this.key.remaining = Math.max(0, this.key.remaining - 1); + this.key.remaining = Math.max(0, this.key.remaining - req.cost); this.state.waitUntil( this.db .update(schema.keys) - .set({ remaining: sql`${schema.keys.remaining}-1` }) + .set({ remaining: sql`${schema.keys.remaining}-${req.cost}` }) .where( and( eq(schema.keys.id, this.key.id), diff --git a/apps/api/src/pkg/usagelimit/interface.ts b/apps/api/src/pkg/usagelimit/interface.ts index 6c7dcaa9da..5493e0efdb 100644 --- a/apps/api/src/pkg/usagelimit/interface.ts +++ b/apps/api/src/pkg/usagelimit/interface.ts @@ -2,6 +2,7 @@ import { z } from "zod"; export const limitRequestSchema = z.object({ keyId: z.string(), + cost: z.number(), }); export type LimitRequest = z.infer; diff --git a/apps/api/src/routes/v1_keys_verifyKey.test.ts b/apps/api/src/routes/v1_keys_verifyKey.test.ts index a5534e16ca..9e94a2d124 100644 --- a/apps/api/src/routes/v1_keys_verifyKey.test.ts +++ b/apps/api/src/routes/v1_keys_verifyKey.test.ts @@ -423,6 +423,66 @@ describe("with default ratelimit", () => { }); }); +describe("with remaining", () => { + test("custom cost works", async (t) => { + const h = await IntegrationHarness.init(t); + const key = new KeyV1({ prefix: "test", byteLength: 16 }).toString(); + await h.db.primary.insert(schema.keys).values({ + id: newId("test"), + keyAuthId: h.resources.userKeyAuth.id, + hash: await sha256(key), + start: key.slice(0, 8), + workspaceId: h.resources.userWorkspace.id, + createdAt: new Date(), + remaining: 10, + }); + + const res = await h.post({ + url: "/v1/keys.verifyKey", + headers: { + "Content-Type": "application/json", + }, + body: { + key, + apiId: h.resources.userApi.id, + remaining: { cost: 2 }, + }, + }); + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + expect(res.body.valid).toBe(true); + expect(res.body.remaining).toEqual(8); + }); + + test("cost=0 works even when remaining=0", async (t) => { + const h = await IntegrationHarness.init(t); + const key = new KeyV1({ prefix: "test", byteLength: 16 }).toString(); + await h.db.primary.insert(schema.keys).values({ + id: newId("test"), + keyAuthId: h.resources.userKeyAuth.id, + hash: await sha256(key), + start: key.slice(0, 8), + workspaceId: h.resources.userWorkspace.id, + createdAt: new Date(), + remaining: 0, + }); + + const res = await h.post({ + url: "/v1/keys.verifyKey", + headers: { + "Content-Type": "application/json", + }, + body: { + key, + apiId: h.resources.userApi.id, + remaining: { cost: 0 }, + }, + }); + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + expect(res.body.valid).toBe(true); + expect(res.body.remaining).toEqual(0); + }); +}); + describe("with ratelimit", () => { describe("with valid key", () => { test.skip( diff --git a/apps/api/src/routes/v1_keys_verifyKey.ts b/apps/api/src/routes/v1_keys_verifyKey.ts index 39b4f50482..743174657b 100644 --- a/apps/api/src/routes/v1_keys_verifyKey.ts +++ b/apps/api/src/routes/v1_keys_verifyKey.ts @@ -90,6 +90,18 @@ The key will be verified against the api's configuration. If the key does not be .openapi({ description: "Perform RBAC checks", }), + remaining: z + .object({ + cost: z.number().int().default(1).openapi({ + description: + "How many tokens should be deducted from the current `remaining` value. Set it to 0, to make it free.", + }), + }) + .optional() + .openapi({ + description: + "Customize the behaviour of deducting remaining uses. When some of your endpoints are more expensive than others, you can set a custom `cost` for each.", + }), ratelimit: z .object({ cost: z.number().int().min(0).optional().default(1).openapi({ @@ -195,6 +207,7 @@ A key could be invalid for a number of reasons, for example if it has expired, h "The unix timestamp in milliseconds when the key will expire. If this field is null or undefined, the key is not expiring.", example: 123, }), + ratelimit: z .object({ limit: z.number().int().openapi({ @@ -310,6 +323,7 @@ export const registerV1KeysVerifyKey = (app: App) => permissionQuery: req.authorization?.permissions, ratelimit: req.ratelimit, ratelimits: req.ratelimits, + remaining: req.remaining, }); if (err) {