diff --git a/apps/api/src/pkg/keys/service.ts b/apps/api/src/pkg/keys/service.ts index d01b5fb087..503da59b46 100644 --- a/apps/api/src/pkg/keys/service.ts +++ b/apps/api/src/pkg/keys/service.ts @@ -608,7 +608,6 @@ export class KeyService { }); } - console.warn("ABC"); let remaining: number | undefined = undefined; if (data.key.remaining !== null) { const t0 = performance.now(); diff --git a/go/apps/api/config.go b/go/apps/api/config.go index 6c865d9587..fbeb1f2e6d 100644 --- a/go/apps/api/config.go +++ b/go/apps/api/config.go @@ -5,10 +5,17 @@ import ( "github.com/unkeyed/unkey/go/pkg/tls" ) -type Config struct { +type S3Config struct { + URL string + Bucket string + AccessKeyID string + SecretAccessKey string +} +type Config struct { // InstanceID is the unique identifier for this instance of the API server InstanceID string + // Platform identifies the cloud platform where the node is running (e.g., aws, gcp, hetzner) Platform string @@ -54,7 +61,9 @@ type Config struct { // TLSConfig provides HTTPS support when set TLSConfig *tls.Config + // Vault Configuration VaultMasterKeys []string + VaultS3 *S3Config } func (c Config) Validate() error { diff --git a/go/apps/api/integration/harness.go b/go/apps/api/integration/harness.go index bd4b726690..eb07aefd27 100644 --- a/go/apps/api/integration/harness.go +++ b/go/apps/api/integration/harness.go @@ -156,6 +156,7 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { PrometheusPort: 0, TLSConfig: nil, VaultMasterKeys: []string{"Ch9rZWtfMmdqMFBJdVhac1NSa0ZhNE5mOWlLSnBHenFPENTt7an5MRogENt9Si6wms4pQ2XIvqNSIgNpaBenJmXgcInhu6Nfv2U="}, // Test key from docker-compose + VaultS3: nil, } // Start API server in goroutine diff --git a/go/apps/api/openapi/gen.go b/go/apps/api/openapi/gen.go index 20c8fef1b2..2edf28d30f 100644 --- a/go/apps/api/openapi/gen.go +++ b/go/apps/api/openapi/gen.go @@ -29,7 +29,6 @@ const ( INSUFFICIENTPERMISSIONS KeysVerifyKeyResponseDataCode = "INSUFFICIENT_PERMISSIONS" NOTFOUND KeysVerifyKeyResponseDataCode = "NOT_FOUND" RATELIMITED KeysVerifyKeyResponseDataCode = "RATE_LIMITED" - UNAUTHORIZED KeysVerifyKeyResponseDataCode = "UNAUTHORIZED" USAGEEXCEEDED KeysVerifyKeyResponseDataCode = "USAGE_EXCEEDED" VALID KeysVerifyKeyResponseDataCode = "VALID" ) @@ -320,6 +319,33 @@ type KeysDeleteKeyResponseData = map[string]interface{} // KeysUpdateKeyResponseData Empty response object by design. A successful response indicates the key was updated successfully. The endpoint doesn't return the updated key to reduce response size and avoid exposing sensitive information. Changes may take up to 30 seconds to propagate to all regions due to cache invalidation delays. If you need the updated key state, use a subsequent call to `keys.getKey`. type KeysUpdateKeyResponseData = map[string]interface{} +// KeysVerifyKeyCredits Controls credit consumption for usage-based billing and quota enforcement. +// Omitting this field uses the default cost of 1 credit per verification. +// Credits provide globally consistent usage tracking, essential for paid APIs with strict quotas. +// Verification can succeed while credit deduction fails if the key has insufficient credits. +type KeysVerifyKeyCredits struct { + // Cost Sets how many credits to deduct for this verification request. + // Use 0 for read-only operations or free tier access, higher values for premium features. + // Credits are deducted immediately upon verification, even if the key lacks required permissions. + // Essential for implementing usage-based pricing with different operation costs. + Cost int32 `json:"cost"` +} + +// KeysVerifyKeyRatelimit defines model for KeysVerifyKeyRatelimit. +type KeysVerifyKeyRatelimit struct { + // Cost Optionally override how expensive this operation is and how many tokens are deducted from the current limit. + Cost *int `json:"cost,omitempty"` + + // Duration Optionally override the duration of the rate limit window duration. + Duration *int `json:"duration,omitempty"` + + // Limit Optionally override the maximum number of requests allowed within the specified interval. + Limit *int `json:"limit,omitempty"` + + // Name References an existing ratelimit by its name. Key Ratelimits will take precedence over identifier-based limits. + Name string `json:"name"` +} + // KeysVerifyKeyResponseData defines model for KeysVerifyKeyResponseData. type KeysVerifyKeyResponseData struct { // Code A machine-readable code indicating the verification status or failure reason. Values: `VALID` (key is valid), `NOT_FOUND` (key doesn't exist), `FORBIDDEN` (key exists but belongs to a different API), `USAGE_EXCEEDED` (key has no more credits), `RATE_LIMITED` (key exceeded rate limits), `UNAUTHORIZED` (key can't be used for this action), `DISABLED` (key was explicitly disabled), `INSUFFICIENT_PERMISSIONS` (key lacks required permissions), `EXPIRED` (key has passed its expiration date). @@ -332,11 +358,8 @@ type KeysVerifyKeyResponseData struct { Enabled *bool `json:"enabled,omitempty"` // Expires Unix timestamp (in milliseconds) when the key will expire. If null or not present, the key has no expiration. You can use this to warn users about upcoming expirations or to understand the validity period. - Expires *int64 `json:"expires,omitempty"` - - // ExternalId Your user/tenant identifier that was associated with this key during creation. This allows you to connect the key back to your user without additional database lookups, making it ideal for implementing user-based authorization in stateless services. - ExternalId *string `json:"externalId,omitempty"` - Identity *Identity `json:"identity,omitempty"` + Expires *int64 `json:"expires,omitempty"` + Identity *Identity `json:"identity,omitempty"` // KeyId The unique identifier of the verified key in Unkey's system. Use this ID for operations like updating or revoking the key. This field is returned for both valid and invalid keys (except when `code=NOT_FOUND`). KeyId *string `json:"keyId,omitempty"` @@ -348,8 +371,8 @@ type KeysVerifyKeyResponseData struct { Name *string `json:"name,omitempty"` // Permissions A list of all permission names assigned to this key, either directly or through roles. These permissions determine what actions the key can perform. Only returned when permissions were checked during verification or when the key fails with `code=INSUFFICIENT_PERMISSIONS`. - Permissions *[]string `json:"permissions,omitempty"` - Ratelimits *[]RatelimitResponse `json:"ratelimits,omitempty"` + Permissions *[]string `json:"permissions,omitempty"` + Ratelimits *[]VerifyKeyRatelimitData `json:"ratelimits,omitempty"` // Roles A list of all role names assigned to this key. Roles are collections of permissions that grant access to specific functionality. Only returned when permissions were checked during verification. Roles *[]string `json:"roles,omitempty"` @@ -418,6 +441,9 @@ type Permission struct { // Use clear, semantic names that reflect the resources or actions being permitted. // Names must be unique within your workspace to avoid confusion and conflicts. Name string `json:"name"` + + // Slug The URL-safe identifier when this permission was created. + Slug string `json:"slug"` } // PermissionsCreatePermissionResponseData defines model for PermissionsCreatePermissionResponseData. @@ -588,8 +614,7 @@ type RatelimitRequest struct { // RatelimitResponse defines model for RatelimitResponse. type RatelimitResponse struct { - // AutoApply Whether this rate limit should be automatically applied when verifying keys. - // When true, we will automatically apply this limit during verification without it being explicitly listed. + // AutoApply Whether this rate limit was automatically applied when verifying the key. AutoApply bool `json:"autoApply"` // Duration Rate limit window duration in milliseconds. @@ -1673,13 +1698,7 @@ type V2KeysVerifyKeyRequestBody struct { // Omitting this field uses the default cost of 1 credit per verification. // Credits provide globally consistent usage tracking, essential for paid APIs with strict quotas. // Verification can succeed while credit deduction fails if the key has insufficient credits. - Credits *struct { - // Cost Sets how many credits to deduct for this verification request. - // Use 0 for read-only operations or free tier access, higher values for premium features. - // Credits are deducted immediately upon verification, even if the key lacks required permissions. - // Essential for implementing usage-based pricing with different operation costs. - Cost *int64 `json:"cost,omitempty"` - } `json:"credits,omitempty"` + Credits *KeysVerifyKeyCredits `json:"credits,omitempty"` // Key The complete API key string provided by your user, including any prefix. // Verification uses secure hashing algorithms without storing plaintext values. @@ -1697,7 +1716,7 @@ type V2KeysVerifyKeyRequestBody struct { // Omitting this field skips rate limit checks entirely, relying only on configured key rate limits. // Multiple rate limits can be checked simultaneously, each with different costs and temporary overrides. // Rate limit checks are optimized for performance but may allow brief bursts during high concurrency. - Ratelimits *[]RatelimitRequest `json:"ratelimits,omitempty"` + Ratelimits *[]KeysVerifyKeyRatelimit `json:"ratelimits,omitempty"` // Tags Attaches metadata tags for analytics and monitoring without affecting verification outcomes. // Enables segmentation of API usage in dashboards by endpoint, client version, region, or custom dimensions. @@ -2027,8 +2046,15 @@ type V2RatelimitGetOverrideRequestBody struct { // NamespaceName The name of the rate limit namespace. Either `namespaceId` or `namespaceName` must be provided, but not both. Using `namespaceName` is more human-readable and easier to work with for manual operations and configurations. NamespaceName *string `json:"namespaceName,omitempty"` + union json.RawMessage } +// V2RatelimitGetOverrideRequestBody0 defines model for . +type V2RatelimitGetOverrideRequestBody0 = interface{} + +// V2RatelimitGetOverrideRequestBody1 defines model for . +type V2RatelimitGetOverrideRequestBody1 = interface{} + // V2RatelimitGetOverrideResponseBody defines model for V2RatelimitGetOverrideResponseBody. type V2RatelimitGetOverrideResponseBody struct { Data RatelimitOverride `json:"data"` @@ -2186,6 +2212,34 @@ type ValidationError struct { Message string `json:"message"` } +// VerifyKeyRatelimitData defines model for VerifyKeyRatelimitData. +type VerifyKeyRatelimitData struct { + // AutoApply Whether this rate limit should be automatically applied when verifying keys. + // When true, we will automatically apply this limit during verification without it being explicitly listed. + AutoApply bool `json:"autoApply"` + + // Duration Rate limit window duration in milliseconds. + Duration int64 `json:"duration"` + + // Exceeded Whether the rate limit was exceeded. + Exceeded bool `json:"exceeded"` + + // Id Unique identifier for this rate limit configuration. + Id string `json:"id"` + + // Limit Maximum requests allowed within the time window. + Limit int64 `json:"limit"` + + // Name Human-readable name for this rate limit. + Name string `json:"name"` + + // Remaining Rate limit remaining requests within the time window. + Remaining int64 `json:"remaining"` + + // Reset Rate limit reset duration in milliseconds. + Reset int64 `json:"reset"` +} + // CreateApiJSONRequestBody defines body for CreateApi for application/json ContentType. type CreateApiJSONRequestBody = V2ApisCreateApiRequestBody @@ -2831,3 +2885,125 @@ func (t *V2KeysVerifyKeyRequestBody_Permissions) UnmarshalJSON(b []byte) error { err := t.union.UnmarshalJSON(b) return err } + +// AsV2RatelimitGetOverrideRequestBody0 returns the union data inside the V2RatelimitGetOverrideRequestBody as a V2RatelimitGetOverrideRequestBody0 +func (t V2RatelimitGetOverrideRequestBody) AsV2RatelimitGetOverrideRequestBody0() (V2RatelimitGetOverrideRequestBody0, error) { + var body V2RatelimitGetOverrideRequestBody0 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromV2RatelimitGetOverrideRequestBody0 overwrites any union data inside the V2RatelimitGetOverrideRequestBody as the provided V2RatelimitGetOverrideRequestBody0 +func (t *V2RatelimitGetOverrideRequestBody) FromV2RatelimitGetOverrideRequestBody0(v V2RatelimitGetOverrideRequestBody0) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeV2RatelimitGetOverrideRequestBody0 performs a merge with any union data inside the V2RatelimitGetOverrideRequestBody, using the provided V2RatelimitGetOverrideRequestBody0 +func (t *V2RatelimitGetOverrideRequestBody) MergeV2RatelimitGetOverrideRequestBody0(v V2RatelimitGetOverrideRequestBody0) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsV2RatelimitGetOverrideRequestBody1 returns the union data inside the V2RatelimitGetOverrideRequestBody as a V2RatelimitGetOverrideRequestBody1 +func (t V2RatelimitGetOverrideRequestBody) AsV2RatelimitGetOverrideRequestBody1() (V2RatelimitGetOverrideRequestBody1, error) { + var body V2RatelimitGetOverrideRequestBody1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromV2RatelimitGetOverrideRequestBody1 overwrites any union data inside the V2RatelimitGetOverrideRequestBody as the provided V2RatelimitGetOverrideRequestBody1 +func (t *V2RatelimitGetOverrideRequestBody) FromV2RatelimitGetOverrideRequestBody1(v V2RatelimitGetOverrideRequestBody1) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeV2RatelimitGetOverrideRequestBody1 performs a merge with any union data inside the V2RatelimitGetOverrideRequestBody, using the provided V2RatelimitGetOverrideRequestBody1 +func (t *V2RatelimitGetOverrideRequestBody) MergeV2RatelimitGetOverrideRequestBody1(v V2RatelimitGetOverrideRequestBody1) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t V2RatelimitGetOverrideRequestBody) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + if err != nil { + return nil, err + } + object := make(map[string]json.RawMessage) + if t.union != nil { + err = json.Unmarshal(b, &object) + if err != nil { + return nil, err + } + } + + object["identifier"], err = json.Marshal(t.Identifier) + if err != nil { + return nil, fmt.Errorf("error marshaling 'identifier': %w", err) + } + + if t.NamespaceId != nil { + object["namespaceId"], err = json.Marshal(t.NamespaceId) + if err != nil { + return nil, fmt.Errorf("error marshaling 'namespaceId': %w", err) + } + } + + if t.NamespaceName != nil { + object["namespaceName"], err = json.Marshal(t.NamespaceName) + if err != nil { + return nil, fmt.Errorf("error marshaling 'namespaceName': %w", err) + } + } + b, err = json.Marshal(object) + return b, err +} + +func (t *V2RatelimitGetOverrideRequestBody) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + if err != nil { + return err + } + object := make(map[string]json.RawMessage) + err = json.Unmarshal(b, &object) + if err != nil { + return err + } + + if raw, found := object["identifier"]; found { + err = json.Unmarshal(raw, &t.Identifier) + if err != nil { + return fmt.Errorf("error reading 'identifier': %w", err) + } + } + + if raw, found := object["namespaceId"]; found { + err = json.Unmarshal(raw, &t.NamespaceId) + if err != nil { + return fmt.Errorf("error reading 'namespaceId': %w", err) + } + } + + if raw, found := object["namespaceName"]; found { + err = json.Unmarshal(raw, &t.NamespaceName) + if err != nil { + return fmt.Errorf("error reading 'namespaceName': %w", err) + } + } + + return err +} diff --git a/go/apps/api/openapi/openapi.yaml b/go/apps/api/openapi/openapi.yaml index 16c1ab175c..495c74762f 100644 --- a/go/apps/api/openapi/openapi.yaml +++ b/go/apps/api/openapi/openapi.yaml @@ -1764,9 +1764,7 @@ components: example: 3600000 autoApply: type: boolean - description: | - Whether this rate limit should be automatically applied when verifying keys. - When true, we will automatically apply this limit during verification without it being explicitly listed. + description: Whether this rate limit was automatically applied when verifying the key. example: true required: - id @@ -1931,6 +1929,13 @@ components: required: - identifier type: object + oneOf: + - required: + - namespaceName + - identifier + - required: + - namespaceId + - identifier V2RatelimitGetOverrideResponseBody: type: object required: @@ -2504,6 +2509,13 @@ components: Use clear, semantic names that reflect the resources or actions being permitted. Names must be unique within your workspace to avoid confusion and conflicts. example: "users.read" + slug: + type: string + minLength: 1 + maxLength: 512 + description: | + The URL-safe identifier when this permission was created. + example: users-read description: type: string maxLength: 2048 @@ -2526,6 +2538,7 @@ components: required: - id - name + - slug - createdAt additionalProperties: false V2PermissionsCreatePermissionRequestBody: @@ -3361,7 +3374,7 @@ components: - type: string minLength: 1 maxLength: 100 # Keep permission names concise and readable - pattern: "^[a-zA-Z0-9_]+$" + pattern: "^[a-zA-Z0-9_.]+$" description: | Checks if the key has this specific permission. Supports hierarchical permissions where `documents.*` grants access to `documents.read` and `documents.write`. @@ -3384,7 +3397,7 @@ components: type: string minLength: 1 maxLength: 100 # Keep permission names concise and readable - pattern: "^[a-zA-Z0-9_]+$" + pattern: "^[a-zA-Z0-9_.]+$" maxItems: 100 # Allow complex permission sets without performance degradation minItems: 1 description: | @@ -3402,36 +3415,67 @@ components: When provided, verification fails unless the key has the specified permissions through direct assignment or role inheritance. Essential for implementing fine-grained authorization in multi-tenant or privilege-separated APIs. credits: - type: object - properties: - cost: - type: integer - format: int64 - minimum: 0 - maximum: 1000 # Reasonable upper bound for operation costs - default: 1 - description: | - Sets how many credits to deduct for this verification request. - Use 0 for read-only operations or free tier access, higher values for premium features. - Credits are deducted immediately upon verification, even if the key lacks required permissions. - Essential for implementing usage-based pricing with different operation costs. - example: 5 - additionalProperties: false - description: | - Controls credit consumption for usage-based billing and quota enforcement. - Omitting this field uses the default cost of 1 credit per verification. - Credits provide globally consistent usage tracking, essential for paid APIs with strict quotas. - Verification can succeed while credit deduction fails if the key has insufficient credits. + "$ref": "#/components/schemas/KeysVerifyKeyCredits" ratelimits: type: array items: - "$ref": "#/components/schemas/RatelimitRequest" + "$ref": "#/components/schemas/KeysVerifyKeyRatelimit" description: | Enforces time-based rate limiting during verification to prevent abuse and ensure fair usage. Omitting this field skips rate limit checks entirely, relying only on configured key rate limits. Multiple rate limits can be checked simultaneously, each with different costs and temporary overrides. Rate limit checks are optimized for performance but may allow brief bursts during high concurrency. additionalProperties: false + KeysVerifyKeyCredits: + type: object + required: + - cost + properties: + cost: + type: integer + format: int32 + minimum: 0 + maximum: 1000000000 + description: | + Sets how many credits to deduct for this verification request. + Use 0 for read-only operations or free tier access, higher values for premium features. + Credits are deducted immediately upon verification, even if the key lacks required permissions. + Essential for implementing usage-based pricing with different operation costs. + example: 5 + additionalProperties: false + description: | + Controls credit consumption for usage-based billing and quota enforcement. + Omitting this field uses the default cost of 1 credit per verification. + Credits provide globally consistent usage tracking, essential for paid APIs with strict quotas. + Verification can succeed while credit deduction fails if the key has insufficient credits. + KeysVerifyKeyRatelimit: + type: object + required: + - name + properties: + name: + type: string + minLength: 3 + maxLength: 255 # Reasonable upper bound for database identifiers + pattern: "" + description: References an existing ratelimit by its name. Key Ratelimits will take precedence over identifier-based limits. + example: tokens + cost: + type: integer + minimum: 0 + default: 1 + description: Optionally override how expensive this operation is and how many tokens are deducted from the current limit. + example: 2 + limit: + type: integer + minimum: 0 + description: Optionally override the maximum number of requests allowed within the specified interval. + example: 50 + duration: + type: integer + minimum: 0 + description: Optionally override the duration of the rate limit window duration. + example: 600000 KeysVerifyKeyResponseData: type: object properties: @@ -3449,7 +3493,6 @@ components: - FORBIDDEN - USAGE_EXCEEDED - RATE_LIMITED - - UNAUTHORIZED - DISABLED - INSUFFICIENT_PERMISSIONS - EXPIRED @@ -3473,13 +3516,6 @@ components: The human-readable name assigned to this key during creation. This is useful for displaying in logs or admin interfaces to identify the key's purpose or owner. - externalId: - type: string - description: - Your user/tenant identifier that was associated with this key - during creation. This allows you to connect the key back to your user - without additional database lookups, making it ideal for implementing - user-based authorization in stateless services. meta: type: object additionalProperties: true @@ -3536,14 +3572,75 @@ components: ratelimits: type: array items: - "$ref": "#/components/schemas/RatelimitResponse" - description: - Information about the rate limits applied during verification. - Only included when rate limits were checked. If verification failed with - `code=RATE_LIMITED`, this will show which specific rate limit was exceeded. + "$ref": "#/components/schemas/VerifyKeyRatelimitData" + description: The ratelimits that got checked required: - valid - code + VerifyKeyRatelimitData: + type: object + properties: + exceeded: + type: boolean + description: Whether the rate limit was exceeded. + id: + type: string + minLength: 8 + maxLength: 255 + pattern: "^rl_[a-zA-Z0-9_]+$" + description: Unique identifier for this rate limit configuration. + example: rl_1234567890abcdef + name: + type: string + minLength: 1 + maxLength: 128 + pattern: "^[a-zA-Z][a-zA-Z0-9_-]*$" + description: Human-readable name for this rate limit. + example: api_requests + limit: + type: integer + format: int64 + minimum: 1 + maximum: 1000000 + description: Maximum requests allowed within the time window. + example: 1000 + duration: + type: integer + format: int64 + minimum: 1000 + maximum: 2592000000 + description: Rate limit window duration in milliseconds. + example: 3600000 + reset: + type: integer + format: int64 + minimum: 1000 + maximum: 2592000000 + description: Rate limit reset duration in milliseconds. + example: 3600000 + remaining: + type: integer + format: int64 + minimum: 0 + maximum: 1000000 + description: Rate limit remaining requests within the time window. + example: 999 + autoApply: + type: boolean + description: | + Whether this rate limit should be automatically applied when verifying keys. + When true, we will automatically apply this limit during verification without it being explicitly listed. + example: true + required: + - id + - exceeded + - name + - limit + - duration + - reset + - remaining + - autoApply + additionalProperties: false V2IdentitiesUpdateIdentityRequestBody: type: object properties: diff --git a/go/apps/api/routes/register.go b/go/apps/api/routes/register.go index 96c4e3687b..bf566b34bb 100644 --- a/go/apps/api/routes/register.go +++ b/go/apps/api/routes/register.go @@ -40,6 +40,7 @@ import ( v2KeysSetPermissions "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_permissions" v2KeysSetRoles "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_roles" v2KeysUpdateCredits "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_update_credits" + v2KeysVerifyKey "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" zen "github.com/unkeyed/unkey/go/pkg/zen" ) @@ -76,9 +77,7 @@ func Register(srv *zen.Server, svc *Services) { Keys: svc.Keys, ClickHouse: svc.ClickHouse, Ratelimit: svc.Ratelimit, - Permissions: svc.Permissions, RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: svc.Caches.RatelimitOverridesMatch, TestMode: srv.Flags().TestMode, }, ) @@ -87,11 +86,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2RatelimitSetOverride.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName, }, ) @@ -99,10 +98,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2RatelimitGetOverride.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName, }, ) @@ -110,11 +109,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2RatelimitDeleteOverride.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName, }, ) @@ -122,10 +121,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2RatelimitListOverrides.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -136,11 +134,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2IdentitiesCreateIdentity.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -148,11 +145,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2IdentitiesDeleteIdentity.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -160,10 +156,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2IdentitiesGetIdentity.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -171,10 +166,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2IdentitiesListIdentities.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -182,11 +176,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2IdentitiesUpdateIdentity.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -197,21 +190,19 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2ApisCreateApi.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) // v2/apis.getApi srv.RegisterRoute( defaultMiddlewares, &v2ApisGetApi.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -219,11 +210,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2ApisListKeys.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Vault: svc.Vault, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Vault: svc.Vault, }, ) @@ -231,12 +221,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2ApisDeleteApi.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, - Caches: svc.Caches, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + Caches: svc.Caches, }, ) @@ -247,11 +236,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsCreatePermission.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -259,10 +247,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsGetPermission.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -270,10 +257,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsGetRole.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -281,10 +267,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsListPermissions.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -292,11 +277,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsDeletePermission.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -304,11 +288,10 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsCreateRole.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) @@ -316,10 +299,9 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsListRoles.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, }, ) @@ -327,27 +309,37 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2PermissionsDeleteRole.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, }, ) // --------------------------------------------------------------------------- // v2/keys + // v2/keys.verifyKey + srv.RegisterRoute( + defaultMiddlewares, + &v2KeysVerifyKey.Handler{ + Logger: svc.Logger, + ClickHouse: svc.ClickHouse, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + }, + ) + // v2/keys.createKey srv.RegisterRoute( defaultMiddlewares, &v2KeysCreateKey.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, - Vault: svc.Vault, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + Vault: svc.Vault, }, ) @@ -355,12 +347,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysGetKey.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, - Vault: svc.Vault, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + Vault: svc.Vault, }, ) @@ -368,11 +359,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysUpdateCredits.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) @@ -380,22 +371,23 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysSetRoles.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) + // v2/keys.setPermissions srv.RegisterRoute( defaultMiddlewares, &v2KeysSetPermissions.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) @@ -403,11 +395,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysAddPermissions.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) @@ -415,11 +407,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysAddRoles.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) @@ -427,11 +419,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysRemovePermissions.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) @@ -439,11 +431,11 @@ func Register(srv *zen.Server, svc *Services) { srv.RegisterRoute( defaultMiddlewares, &v2KeysRemoveRoles.Handler{ - Logger: svc.Logger, - DB: svc.Database, - Keys: svc.Keys, - Permissions: svc.Permissions, - Auditlogs: svc.Auditlogs, + Logger: svc.Logger, + DB: svc.Database, + Keys: svc.Keys, + Auditlogs: svc.Auditlogs, + KeyCache: svc.Caches.VerificationKeyByHash, }, ) diff --git a/go/apps/api/routes/services.go b/go/apps/api/routes/services.go index 739bf50521..d88ee87dc8 100644 --- a/go/apps/api/routes/services.go +++ b/go/apps/api/routes/services.go @@ -4,7 +4,6 @@ import ( "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/caches" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/clickhouse/schema" @@ -19,14 +18,13 @@ type EventBuffer interface { } type Services struct { - Logger logging.Logger - Database db.Database - Keys keys.KeyService - ClickHouse clickhouse.ClickHouse - Permissions permissions.PermissionService - Validator *validation.Validator - Ratelimit ratelimit.Service - Auditlogs auditlogs.AuditLogService - Caches caches.Caches - Vault *vault.Service + Logger logging.Logger + Database db.Database + Keys keys.KeyService + ClickHouse clickhouse.ClickHouse + Validator *validation.Validator + Ratelimit ratelimit.Service + Auditlogs auditlogs.AuditLogService + Caches caches.Caches + Vault *vault.Service } diff --git a/go/apps/api/routes/v2_apis_create_api/200_test.go b/go/apps/api/routes/v2_apis_create_api/200_test.go index bf975044f7..0b211932bc 100644 --- a/go/apps/api/routes/v2_apis_create_api/200_test.go +++ b/go/apps/api/routes/v2_apis_create_api/200_test.go @@ -2,16 +2,15 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_create_api" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -24,11 +23,10 @@ func TestCreateApiSuccessfully(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) @@ -43,34 +41,13 @@ func TestCreateApiSuccessfully(t *testing.T) { // This test validates that the underlying database queries work correctly // by bypassing the HTTP handler and directly testing the DB operations. t.Run("insert api via DB", func(t *testing.T) { - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - apiName := "test-api-db" - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + createdAPI := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) - api, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + api, err := db.Query.FindApiByID(ctx, h.DB.RO(), createdAPI.ID) require.NoError(t, err) - require.Equal(t, apiName, api.Name) require.Equal(t, h.Resources().UserWorkspace.ID, api.WorkspaceID) require.True(t, api.KeyAuthID.Valid) - require.Equal(t, keyAuthID, api.KeyAuthID.String) + require.Equal(t, createdAPI.KeyAuthID.String, api.KeyAuthID.String) }) // Test creating a basic API diff --git a/go/apps/api/routes/v2_apis_create_api/400_test.go b/go/apps/api/routes/v2_apis_create_api/400_test.go index f1bba2e3aa..bbc4cd75a8 100644 --- a/go/apps/api/routes/v2_apis_create_api/400_test.go +++ b/go/apps/api/routes/v2_apis_create_api/400_test.go @@ -22,11 +22,10 @@ func TestCreateApi_BadRequest(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_create_api/401_test.go b/go/apps/api/routes/v2_apis_create_api/401_test.go index 0792b9a430..e0c87fe079 100644 --- a/go/apps/api/routes/v2_apis_create_api/401_test.go +++ b/go/apps/api/routes/v2_apis_create_api/401_test.go @@ -17,11 +17,10 @@ func TestCreateApi_Unauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_create_api/403_test.go b/go/apps/api/routes/v2_apis_create_api/403_test.go index fb7003ffc0..8602f2c3f7 100644 --- a/go/apps/api/routes/v2_apis_create_api/403_test.go +++ b/go/apps/api/routes/v2_apis_create_api/403_test.go @@ -20,11 +20,10 @@ func TestCreateApi_Forbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_create_api/handler.go b/go/apps/api/routes/v2_apis_create_api/handler.go index 951bca8cb4..780911cf5a 100644 --- a/go/apps/api/routes/v2_apis_create_api/handler.go +++ b/go/apps/api/routes/v2_apis_create_api/handler.go @@ -10,7 +10,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -25,11 +24,10 @@ type Request = openapi.V2ApisCreateApiRequestBody type Response = openapi.V2ApisCreateApiResponseBody type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } func (h *Handler) Method() string { @@ -41,7 +39,7 @@ func (h *Handler) Path() string { } func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -54,18 +52,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.CreateAPI, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.CreateAPI, + }), + ))) if err != nil { return err } @@ -94,6 +87,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, KeyAuthID: sql.NullString{Valid: true, String: keyAuthId}, + IpWhitelist: sql.NullString{Valid: false, String: ""}, CreatedAtM: time.Now().UnixMilli(), }) if err != nil { @@ -108,7 +102,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.APICreateEvent, Display: fmt.Sprintf("Created API %s", apiId), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, ActorType: auditlog.RootKeyActor, diff --git a/go/apps/api/routes/v2_apis_delete_api/200_test.go b/go/apps/api/routes/v2_apis_delete_api/200_test.go index ac4c859b01..72072741e5 100644 --- a/go/apps/api/routes/v2_apis_delete_api/200_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/200_test.go @@ -2,18 +2,15 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestSuccess(t *testing.T) { @@ -21,12 +18,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -45,38 +41,18 @@ func TestSuccess(t *testing.T) { // Test case for deleting an API without keys t.Run("delete api without keys", func(t *testing.T) { - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // Ensure API exists before deletion - apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) - require.Equal(t, apiID, apiBeforeDelete.ID) + require.Equal(t, api.ID, apiBeforeDelete.ID) require.False(t, apiBeforeDelete.DeletedAtM.Valid) // Delete the API req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, handler.Response]( @@ -91,59 +67,29 @@ func TestSuccess(t *testing.T) { require.NotEmpty(t, res.Body.Meta.RequestId) // Verify API is marked as deleted - apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) // Should still find it, just marked as deleted require.True(t, apiAfterDelete.DeletedAtM.Valid) }) // Test case for deleting an API with active keys t.Run("delete api with active keys", func(t *testing.T) { - // Create keyring for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) - // Create the API - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API With Keys", + createKey := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), }) - require.NoError(t, err) - - // Create a key associated with this API - keyID := uid.New(uid.KeyPrefix) - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - // Add other required fields based on your schema - Hash: hash.Sha256(uid.New(uid.TestPrefix)), - Start: "teststart", - }) - require.NoError(t, err) // Ensure API exists before deletion - apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) - require.Equal(t, apiID, apiBeforeDelete.ID) + require.Equal(t, api.ID, apiBeforeDelete.ID) require.False(t, apiBeforeDelete.DeletedAtM.Valid) // Delete the API req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, handler.Response]( @@ -158,53 +104,30 @@ func TestSuccess(t *testing.T) { require.NotEmpty(t, res.Body.Meta.RequestId) // Verify API is marked as deleted - apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) require.True(t, apiAfterDelete.DeletedAtM.Valid) // Check that the key is still accessible (soft delete doesn't cascade to keys) - key, err := db.Query.FindKeyByID(ctx, h.DB.RO(), keyID) + key, err := db.Query.FindKeyByID(ctx, h.DB.RO(), createKey.KeyID) require.NoError(t, err) - require.Equal(t, keyID, key.ID) + require.Equal(t, createKey.KeyID, key.ID) }) // Test case for deleting an API immediately after creation t.Run("delete api immediately after creation", func(t *testing.T) { - // Create keyring for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - // Create the API - apiID := uid.New(uid.APIPrefix) - apiName := "Test Immediate Delete API" - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // Verify the API was created - apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) - require.Equal(t, apiID, apiBeforeDelete.ID) - require.Equal(t, apiName, apiBeforeDelete.Name) + require.Equal(t, api.ID, apiBeforeDelete.ID) + require.Equal(t, api.Name, apiBeforeDelete.Name) require.False(t, apiBeforeDelete.DeletedAtM.Valid) // Immediately delete the API without any delay req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, handler.Response]( @@ -219,7 +142,7 @@ func TestSuccess(t *testing.T) { require.NotEmpty(t, res.Body.Meta.RequestId) // Verify API is marked as deleted - apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) require.True(t, apiAfterDelete.DeletedAtM.Valid) }) diff --git a/go/apps/api/routes/v2_apis_delete_api/400_test.go b/go/apps/api/routes/v2_apis_delete_api/400_test.go index 02eebeb54d..b49843db61 100644 --- a/go/apps/api/routes/v2_apis_delete_api/400_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/400_test.go @@ -15,12 +15,11 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_delete_api/401_test.go b/go/apps/api/routes/v2_apis_delete_api/401_test.go index b198d20ada..12edbd57ce 100644 --- a/go/apps/api/routes/v2_apis_delete_api/401_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/401_test.go @@ -14,12 +14,11 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_delete_api/403_test.go b/go/apps/api/routes/v2_apis_delete_api/403_test.go index fcf4d26b15..e7265d12c4 100644 --- a/go/apps/api/routes/v2_apis_delete_api/403_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/403_test.go @@ -1,32 +1,27 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) func TestAuthorizationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -34,29 +29,7 @@ func TestAuthorizationErrors(t *testing.T) { // Create a workspace workspace := h.Resources().UserWorkspace - // Create an API for testing - - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: h.Clock.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // Test case for insufficient permissions - missing delete_api t.Run("missing delete_api permission", func(t *testing.T) { @@ -69,7 +42,7 @@ func TestAuthorizationErrors(t *testing.T) { } req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, openapi.ForbiddenErrorResponse]( @@ -100,7 +73,7 @@ func TestAuthorizationErrors(t *testing.T) { } req := handler.Request{ - ApiId: apiID, // Using the test API, not the one we have permission for + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, openapi.ForbiddenErrorResponse]( @@ -118,10 +91,7 @@ func TestAuthorizationErrors(t *testing.T) { // Test case for wrong workspace t.Run("wrong workspace", func(t *testing.T) { - // Create a different workspace - - // Create a root key for the other workspace - rootKey := h.CreateRootKey(uid.New(uid.TestPrefix), "api.*.delete_api") + rootKey := h.CreateRootKey(uid.New(uid.WorkspacePrefix), "api.*.delete_api") headers := http.Header{ "Content-Type": {"application/json"}, @@ -129,7 +99,7 @@ func TestAuthorizationErrors(t *testing.T) { } req := handler.Request{ - ApiId: apiID, // API is in the original workspace + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, openapi.NotFoundErrorResponse]( diff --git a/go/apps/api/routes/v2_apis_delete_api/404_test.go b/go/apps/api/routes/v2_apis_delete_api/404_test.go index 00288cd695..4da5c9de23 100644 --- a/go/apps/api/routes/v2_apis_delete_api/404_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/404_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestNotFoundErrors(t *testing.T) { @@ -21,12 +20,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -64,38 +62,17 @@ func TestNotFoundErrors(t *testing.T) { // Test case for API that's already deleted t.Run("already deleted API", func(t *testing.T) { + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: h.Clock.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - err = db.Query.SoftDeleteApi(ctx, h.DB.RW(), db.SoftDeleteApiParams{ - ApiID: apiID, + err := db.Query.SoftDeleteApi(ctx, h.DB.RW(), db.SoftDeleteApiParams{ + ApiID: api.ID, Now: sql.NullInt64{Valid: true, Int64: h.Clock.Now().UnixMilli()}, }) require.NoError(t, err) // Try to delete it again req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, openapi.NotFoundErrorResponse]( diff --git a/go/apps/api/routes/v2_apis_delete_api/412_test.go b/go/apps/api/routes/v2_apis_delete_api/412_test.go index e3fa99099e..5f739498ac 100644 --- a/go/apps/api/routes/v2_apis_delete_api/412_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/412_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestDeleteProtection(t *testing.T) { @@ -21,12 +20,11 @@ func TestDeleteProtection(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -45,45 +43,24 @@ func TestDeleteProtection(t *testing.T) { // Test case for deleting an API with delete protection enabled t.Run("delete protected API", func(t *testing.T) { + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: h.Clock.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - err = db.Query.UpdateApiDeleteProtection(ctx, h.DB.RW(), db.UpdateApiDeleteProtectionParams{ - ApiID: apiID, + err := db.Query.UpdateApiDeleteProtection(ctx, h.DB.RW(), db.UpdateApiDeleteProtectionParams{ + ApiID: api.ID, DeleteProtection: sql.NullBool{Valid: true, Bool: true}, }) require.NoError(t, err) // Ensure API exists and has delete protection - apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiBeforeDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) - require.Equal(t, apiID, apiBeforeDelete.ID) + require.Equal(t, api.ID, apiBeforeDelete.ID) require.True(t, apiBeforeDelete.DeleteProtection.Valid) require.True(t, apiBeforeDelete.DeleteProtection.Bool) // Attempt to delete the API req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, openapi.PreconditionFailedErrorResponse]( @@ -99,7 +76,7 @@ func TestDeleteProtection(t *testing.T) { require.Equal(t, "This API has delete protection enabled. Disable it before attempting to delete.", res.Body.Error.Detail) // Verify API was NOT deleted - apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) require.False(t, apiAfterDelete.DeletedAtM.Valid, "API should not have been deleted") }) diff --git a/go/apps/api/routes/v2_apis_delete_api/cache_validation_test.go b/go/apps/api/routes/v2_apis_delete_api/cache_validation_test.go index da12b935d7..1ae0abe401 100644 --- a/go/apps/api/routes/v2_apis_delete_api/cache_validation_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/cache_validation_test.go @@ -2,11 +2,9 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" @@ -14,7 +12,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestCacheInvalidation(t *testing.T) { @@ -22,12 +20,11 @@ func TestCacheInvalidation(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -46,38 +43,16 @@ func TestCacheInvalidation(t *testing.T) { // Test case for verifying cache invalidation t.Run("verify cache invalidation", func(t *testing.T) { - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: h.Clock.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // Get API to ensure it's in the cache - _, err = h.Caches.ApiByID.SWR(ctx, apiID, func(ctx context.Context) (db.Api, error) { - - return db.Query.FindApiByID(ctx, h.DB.RO(), apiID) - + _, err := h.Caches.ApiByID.SWR(ctx, api.ID, func(ctx context.Context) (db.Api, error) { + return db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) }, caches.DefaultFindFirstOp) require.NoError(t, err) // Delete the API req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, handler.Response]( @@ -90,12 +65,12 @@ func TestCacheInvalidation(t *testing.T) { require.Equal(t, 200, res.Status) // Verify API is soft-deleted in the database - apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + apiAfterDelete, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) require.True(t, apiAfterDelete.DeletedAtM.Valid) // Verify the API is deleted in the cache - _, hit := h.Caches.ApiByID.Get(ctx, apiID) + _, hit := h.Caches.ApiByID.Get(ctx, api.ID) require.Equal(t, cache.Null, hit) }) } diff --git a/go/apps/api/routes/v2_apis_delete_api/handler.go b/go/apps/api/routes/v2_apis_delete_api/handler.go index 73d95ef543..a12fdc042f 100644 --- a/go/apps/api/routes/v2_apis_delete_api/handler.go +++ b/go/apps/api/routes/v2_apis_delete_api/handler.go @@ -11,7 +11,6 @@ import ( "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/caches" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -27,12 +26,11 @@ type Response = openapi.V2ApisDeleteApiResponseBody // Handler implements zen.Route interface for the v2 APIs delete API endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService - Caches caches.Caches + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + Caches caches.Caches } // Method returns the HTTP method this route responds to @@ -47,7 +45,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -60,22 +58,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.DeleteAPI, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: req.ApiId, - Action: rbac.DeleteAPI, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.DeleteAPI, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: req.ApiId, + Action: rbac.DeleteAPI, + }), + ))) if err != nil { return err } @@ -114,12 +108,12 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if api.DeleteProtection.Valid && api.DeleteProtection.Bool { return fault.New("delete protected", fault.Code(codes.App.Protection.ProtectedResource.URN()), - fault.Internal("api is protected from deletion"), fault.Public("This API has delete protection enabled. Disable it before attempting to delete."), + fault.Internal("api is protected from deletion"), + fault.Public("This API has delete protection enabled. Disable it before attempting to delete."), ) } now := time.Now() - err = db.Tx(ctx, h.DB.RW(), func(ctx context.Context, tx db.DBTX) error { // Soft delete the API err = db.Query.SoftDeleteApi(ctx, tx, db.SoftDeleteApiParams{ @@ -138,7 +132,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.APIDeleteEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Deleted API %s", req.ApiId), diff --git a/go/apps/api/routes/v2_apis_delete_api/idempotent_test.go b/go/apps/api/routes/v2_apis_delete_api/idempotent_test.go index b407feb77a..31f502a17f 100644 --- a/go/apps/api/routes/v2_apis_delete_api/idempotent_test.go +++ b/go/apps/api/routes/v2_apis_delete_api/idempotent_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_delete_api" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestIdempotentDeletion(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Caches: h.Caches, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Caches: h.Caches, } h.Register(route) @@ -45,31 +39,11 @@ func TestIdempotentDeletion(t *testing.T) { // Test case for idempotent deletion t.Run("idempotent deletion - multiple delete requests", func(t *testing.T) { - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: h.Clock.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - StoreEncryptedKeys: false, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // First deletion - should succeed req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res1 := testutil.CallRoute[handler.Request, handler.Response]( diff --git a/go/apps/api/routes/v2_apis_get_api/200_test.go b/go/apps/api/routes/v2_apis_get_api/200_test.go index 9f11573644..dc41c2d8e6 100644 --- a/go/apps/api/routes/v2_apis_get_api/200_test.go +++ b/go/apps/api/routes/v2_apis_get_api/200_test.go @@ -11,8 +11,9 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_get_api" "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestGetApiSuccessfully(t *testing.T) { @@ -20,10 +21,9 @@ func TestGetApiSuccessfully(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) @@ -37,16 +37,11 @@ func TestGetApiSuccessfully(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create a test API - apiID := uid.New(uid.APIPrefix) apiName := "test-get-existing-api" - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) // Make the request to get the API res := testutil.CallRoute[handler.Request, handler.Response]( @@ -54,29 +49,23 @@ func TestGetApiSuccessfully(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name) }) // Test with different authorization scopes t.Run("authorization scopes", func(t *testing.T) { - // Create a new test API - apiName := "test-get-api" - apiID := uid.New(uid.APIPrefix) - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + apiName := "test-get-existing-api" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) testCases := []struct { name string @@ -95,7 +84,7 @@ func TestGetApiSuccessfully(t *testing.T) { }, { name: "specific api permission", - permissions: []string{fmt.Sprintf("api.%s.read_api", apiID)}, + permissions: []string{fmt.Sprintf("api.%s.read_api", api.ID)}, expectedStatus: 200, }, { @@ -118,55 +107,20 @@ func TestGetApiSuccessfully(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, tc.expectedStatus, res.Status, "expected %d, received: %#v", tc.expectedStatus, res) if tc.expectedStatus == 200 { require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name) } }) } }) - // Test with API that has IP whitelist - t.Run("get api with ip whitelist", func(t *testing.T) { - rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_api") - headers := http.Header{ - "Content-Type": {"application/json"}, - "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, - } - - // Create API with IP whitelist - apiID := uid.New(uid.APIPrefix) - apiName := "api-with-ip-whitelist" - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - res := testutil.CallRoute[handler.Request, handler.Response]( - h, - route, - headers, - handler.Request{ - ApiId: apiID, - }, - ) - - require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) - require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) - require.Equal(t, apiName, res.Body.Data.Name) - }) - // Test API with very long name t.Run("get api with very long name", func(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_api") @@ -176,29 +130,23 @@ func TestGetApiSuccessfully(t *testing.T) { } // Create API with a very long name - apiID := uid.New(uid.APIPrefix) apiName := "this-is-a-very-long-api-name-for-testing-the-limits-of-what-the-system-can-handle-when-dealing-with-extremely-verbose-identifiers-that-might-challenge-database-storage-ui-rendering-and-overall-system-performance-with-edge-cases" - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) - res := testutil.CallRoute[handler.Request, handler.Response]( h, route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name, "The long name should be returned exactly as stored") }) @@ -211,29 +159,24 @@ func TestGetApiSuccessfully(t *testing.T) { } // Create API with special characters in name - apiID := uid.New(uid.APIPrefix) apiName := "special!@#$%^&*()_+-=[]{}|;:,.<>?/~` characters" - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) res := testutil.CallRoute[handler.Request, handler.Response]( h, route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name, "Special characters should be preserved in the name") }) @@ -246,29 +189,24 @@ func TestGetApiSuccessfully(t *testing.T) { } // Create API with Unicode characters in name - apiID := uid.New(uid.APIPrefix) apiName := "Unicode 测试 API 名称 🔑 🔒 ✅" - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) res := testutil.CallRoute[handler.Request, handler.Response]( h, route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name, "Unicode characters should be preserved in the name") }) @@ -284,16 +222,12 @@ func TestGetApiSuccessfully(t *testing.T) { creationTime := time.Now().UnixMilli() // Create a new API - apiID := uid.New(uid.APIPrefix) apiName := "recent-api" - - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: creationTime, + Name: &apiName, + CreatedAt: &creationTime, }) - require.NoError(t, err) // Immediately retrieve the API res := testutil.CallRoute[handler.Request, handler.Response]( @@ -301,17 +235,17 @@ func TestGetApiSuccessfully(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, api.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name) // Verify in database that timestamp is correct - api, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + api, err := db.Query.FindApiByID(ctx, h.DB.RO(), api.ID) require.NoError(t, err) require.Equal(t, creationTime, api.CreatedAtM, "Creation timestamp should match") }) @@ -324,36 +258,20 @@ func TestGetApiSuccessfully(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create keyring for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: true, String: "test_"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - StoreEncryptedKeys: true, - }) - require.NoError(t, err) - - // Create API with all fields populated and delete protection enabled - apiID := uid.New(uid.APIPrefix) apiName := "complete-verification-api" creationTime := time.Now().UnixMilli() - - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: apiName, - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: creationTime, + createdApi := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: &apiName, + EncryptedKeys: true, + CreatedAt: &creationTime, + DefaultPrefix: ptr.P("test_"), + DefaultBytes: ptr.P(int32(16)), }) - require.NoError(t, err) // Set delete protection after API creation - err = db.Query.UpdateApiDeleteProtection(ctx, h.DB.RW(), db.UpdateApiDeleteProtectionParams{ - ApiID: apiID, + err := db.Query.UpdateApiDeleteProtection(ctx, h.DB.RW(), db.UpdateApiDeleteProtectionParams{ + ApiID: createdApi.ID, DeleteProtection: sql.NullBool{Valid: true, Bool: true}, }) require.NoError(t, err) @@ -364,17 +282,17 @@ func TestGetApiSuccessfully(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: createdApi.ID, }, ) require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) require.NotNil(t, res.Body) - require.Equal(t, apiID, res.Body.Data.Id) + require.Equal(t, createdApi.ID, res.Body.Data.Id) require.Equal(t, apiName, res.Body.Data.Name) // Verify database record matches exactly what's returned - api, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + api, err := db.Query.FindApiByID(ctx, h.DB.RO(), createdApi.ID) require.NoError(t, err) // Verify core fields @@ -387,7 +305,7 @@ func TestGetApiSuccessfully(t *testing.T) { require.True(t, api.AuthType.Valid) require.Equal(t, db.ApisAuthTypeKey, api.AuthType.ApisAuthType) require.True(t, api.KeyAuthID.Valid) - require.Equal(t, keyAuthID, api.KeyAuthID.String) + require.Equal(t, createdApi.KeyAuthID.String, api.KeyAuthID.String) require.Equal(t, creationTime, api.CreatedAtM) require.True(t, api.DeleteProtection.Valid) require.True(t, api.DeleteProtection.Bool) diff --git a/go/apps/api/routes/v2_apis_get_api/400_test.go b/go/apps/api/routes/v2_apis_get_api/400_test.go index 775400fabe..6da942e6a3 100644 --- a/go/apps/api/routes/v2_apis_get_api/400_test.go +++ b/go/apps/api/routes/v2_apis_get_api/400_test.go @@ -14,10 +14,9 @@ func TestGetApiInvalidRequest(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) @@ -68,10 +67,6 @@ func TestGetApiInvalidRequest(t *testing.T) { require.NotEmpty(t, res.Body.Error) }) - // We're unable to directly test invalid JSON types using CallRoute as Go's type system will prevent it - - // We're unable to directly test non-JSON content using CallRoute - // Test with a valid apiId t.Run("valid request", func(t *testing.T) { // Create a test API in the database diff --git a/go/apps/api/routes/v2_apis_get_api/403_test.go b/go/apps/api/routes/v2_apis_get_api/403_test.go index a578ea1bd3..2a547d05c3 100644 --- a/go/apps/api/routes/v2_apis_get_api/403_test.go +++ b/go/apps/api/routes/v2_apis_get_api/403_test.go @@ -1,42 +1,30 @@ package handler_test import ( - "context" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_get_api" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) func TestGetApiInsufficientPermissions(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) - // Create a test API - apiID := uid.New(uid.APIPrefix) - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api-permissions", - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) testCases := []struct { name string @@ -56,7 +44,7 @@ func TestGetApiInsufficientPermissions(t *testing.T) { }, { name: "wrong scope for specific api", - permissions: []string{fmt.Sprintf("api.%s.create_api", apiID)}, + permissions: []string{fmt.Sprintf("api.%s.create_api", api.ID)}, }, { name: "permission for different api", @@ -81,7 +69,7 @@ func TestGetApiInsufficientPermissions(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) @@ -108,7 +96,7 @@ func TestGetApiInsufficientPermissions(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: api.ID, }, ) diff --git a/go/apps/api/routes/v2_apis_get_api/404_test.go b/go/apps/api/routes/v2_apis_get_api/404_test.go index 01272b9b9e..d955f25de2 100644 --- a/go/apps/api/routes/v2_apis_get_api/404_test.go +++ b/go/apps/api/routes/v2_apis_get_api/404_test.go @@ -13,6 +13,7 @@ import ( handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_get_api" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,10 +22,9 @@ func TestGetApiNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) @@ -58,22 +58,14 @@ func TestGetApiNotFound(t *testing.T) { // Create a different workspace otherWorkspaceID := uid.New(uid.WorkspacePrefix) - // Create API in the different workspace - apiID := uid.New(uid.APIPrefix) - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "other-workspace-api", - WorkspaceID: otherWorkspaceID, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + diffApi := h.CreateApi(seed.CreateApiRequest{WorkspaceID: otherWorkspaceID}) res := testutil.CallRoute[handler.Request, openapi.NotFoundErrorResponse]( h, route, headers, handler.Request{ - ApiId: apiID, + ApiId: diffApi.ID, }, ) @@ -84,31 +76,23 @@ func TestGetApiNotFound(t *testing.T) { // Test with soft-deleted API t.Run("deleted api", func(t *testing.T) { - // Create an API - apiID := uid.New(uid.APIPrefix) - err := db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "to-be-deleted-api", - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) + diffApi := h.CreateApi(seed.CreateApiRequest{WorkspaceID: h.Resources().UserWorkspace.ID}) // Verify it exists - api, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + api, err := db.Query.FindApiByID(ctx, h.DB.RO(), diffApi.ID) require.NoError(t, err) - require.Equal(t, apiID, api.ID) + require.Equal(t, diffApi.ID, api.ID) require.False(t, api.DeletedAtM.Valid) // Mark API as deleted by setting DeletedAtM err = db.Query.SoftDeleteApi(ctx, h.DB.RW(), db.SoftDeleteApiParams{ - ApiID: apiID, + ApiID: diffApi.ID, Now: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, }) require.NoError(t, err) // Verify it's marked as deleted - deletedApi, err := db.Query.FindApiByID(ctx, h.DB.RO(), apiID) + deletedApi, err := db.Query.FindApiByID(ctx, h.DB.RO(), diffApi.ID) require.NoError(t, err) require.True(t, deletedApi.DeletedAtM.Valid) @@ -118,7 +102,7 @@ func TestGetApiNotFound(t *testing.T) { route, headers, handler.Request{ - ApiId: apiID, + ApiId: diffApi.ID, }, ) diff --git a/go/apps/api/routes/v2_apis_get_api/handler.go b/go/apps/api/routes/v2_apis_get_api/handler.go index a799433289..bdee59bfe4 100644 --- a/go/apps/api/routes/v2_apis_get_api/handler.go +++ b/go/apps/api/routes/v2_apis_get_api/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -21,10 +20,9 @@ type Response = openapi.V2ApisGetApiResponseBody // Handler implements zen.Route interface for the v2 APIs get API endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -40,7 +38,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/apis.getApi") - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -53,22 +51,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.ReadAPI, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: req.ApiId, - Action: rbac.ReadAPI, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.ReadAPI, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: req.ApiId, + Action: rbac.ReadAPI, + }), + ))) if err != nil { return err } diff --git a/go/apps/api/routes/v2_apis_list_keys/200_test.go b/go/apps/api/routes/v2_apis_list_keys/200_test.go index d31a46f62c..617cfa707e 100644 --- a/go/apps/api/routes/v2_apis_list_keys/200_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/200_test.go @@ -25,11 +25,10 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) @@ -164,10 +163,6 @@ func TestSuccess(t *testing.T) { CreatedAtM: time.Now().UnixMilli(), Enabled: keyData.enabled, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, } if keyData.identityID != nil { @@ -513,11 +508,8 @@ func TestSuccess(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - IdentityID: sql.NullString{Valid: false}, + + IdentityID: sql.NullString{Valid: false}, }) require.NoError(t, err) @@ -536,11 +528,8 @@ func TestSuccess(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - IdentityID: sql.NullString{Valid: false}, + + IdentityID: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_apis_list_keys/400_test.go b/go/apps/api/routes/v2_apis_list_keys/400_test.go index 558b9d2a30..e76db8b894 100644 --- a/go/apps/api/routes/v2_apis_list_keys/400_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/400_test.go @@ -15,11 +15,10 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_list_keys/401_test.go b/go/apps/api/routes/v2_apis_list_keys/401_test.go index 264ee30fcf..06a93d61a0 100644 --- a/go/apps/api/routes/v2_apis_list_keys/401_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/401_test.go @@ -14,11 +14,10 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_list_keys/403_test.go b/go/apps/api/routes/v2_apis_list_keys/403_test.go index 72ed3c0397..3c31665242 100644 --- a/go/apps/api/routes/v2_apis_list_keys/403_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/403_test.go @@ -21,11 +21,10 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_list_keys/404_test.go b/go/apps/api/routes/v2_apis_list_keys/404_test.go index b8056bda24..1ef23291e2 100644 --- a/go/apps/api/routes/v2_apis_list_keys/404_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/404_test.go @@ -13,6 +13,7 @@ import ( handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_apis_list_keys" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,11 +22,10 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) @@ -243,28 +243,13 @@ func TestNotFoundErrors(t *testing.T) { // Test case for API that exists but has no keys (should return 200 with empty array) t.Run("API exists but has no keys", func(t *testing.T) { - // Create a keyAuth for the API - emptyKeyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: emptyKeyAuthID, - WorkspaceID: workspace1.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create API with no keys - emptyApiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: emptyApiID, - Name: "API with no keys", + // Create API with no keys using testutil helper + apiName := "API with no keys" + emptyApi := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace1.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: emptyKeyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) + emptyApiID := emptyApi.ID req := handler.Request{ ApiId: emptyApiID, diff --git a/go/apps/api/routes/v2_apis_list_keys/412_test.go b/go/apps/api/routes/v2_apis_list_keys/412_test.go index b4ee049fec..4db71b0c35 100644 --- a/go/apps/api/routes/v2_apis_list_keys/412_test.go +++ b/go/apps/api/routes/v2_apis_list_keys/412_test.go @@ -22,11 +22,10 @@ func TestPreconditionError(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_apis_list_keys/handler.go b/go/apps/api/routes/v2_apis_list_keys/handler.go index 85e9e1cc88..0ae8760050 100644 --- a/go/apps/api/routes/v2_apis_list_keys/handler.go +++ b/go/apps/api/routes/v2_apis_list_keys/handler.go @@ -10,7 +10,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" vaultv1 "github.com/unkeyed/unkey/go/gen/proto/vault/v1" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -26,11 +25,10 @@ type Response = openapi.V2ApisListKeysResponseBody // Handler implements zen.Route interface for the v2 APIs list keys endpoint type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Vault *vault.Service + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Vault *vault.Service } // Method returns the HTTP method this route responds to @@ -47,7 +45,7 @@ func (h *Handler) Path() string { // The current implementation queries the database directly without caching, which may impact performance. // Consider implementing cache with optional bypass via revalidateKeysCache parameter. func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -60,38 +58,34 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.And( - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.ReadKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: req.ApiId, - Action: rbac.ReadKey, - }), - ), - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.ReadAPI, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: req.ApiId, - Action: rbac.ReadAPI, - }), - ), + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.And( + rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.ReadKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: req.ApiId, + Action: rbac.ReadKey, + }), + ), + rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.ReadAPI, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: req.ApiId, + Action: rbac.ReadAPI, + }), ), ), - ) + ))) if err != nil { return err } @@ -150,22 +144,25 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } if ptr.SafeDeref(req.Decrypt, false) { - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.DecryptKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: api.ID, - Action: rbac.DecryptKey, - }), - ), - ) + if h.Vault == nil { + return fault.New("vault missing", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Public("Vault hasn't been set up."), + ) + } + + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.DecryptKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: api.ID, + Action: rbac.DecryptKey, + }), + ))) if err != nil { return err } diff --git a/go/apps/api/routes/v2_identities_create_identity/200_test.go b/go/apps/api/routes/v2_identities_create_identity/200_test.go index a7c50f8877..c91a986860 100644 --- a/go/apps/api/routes/v2_identities_create_identity/200_test.go +++ b/go/apps/api/routes/v2_identities_create_identity/200_test.go @@ -21,11 +21,10 @@ func TestCreateIdentitySuccessfully(t *testing.T) { ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) @@ -119,7 +118,6 @@ func TestCreateIdentitySuccessfully(t *testing.T) { // Test creating a identity with metadata t.Run("create identity with metadata", func(t *testing.T) { externalTestID := uid.New("test_external_id") - meta := &map[string]any{"key": "example"} req := handler.Request{ ExternalId: externalTestID, @@ -147,7 +145,6 @@ func TestCreateIdentitySuccessfully(t *testing.T) { // Test creating a identity with ratelimits t.Run("create identity with ratelimits", func(t *testing.T) { externalTestID := uid.New("test_external_id") - identityRateLimits := []openapi.RatelimitRequest{ { Duration: time.Minute.Milliseconds(), diff --git a/go/apps/api/routes/v2_identities_create_identity/400_test.go b/go/apps/api/routes/v2_identities_create_identity/400_test.go index ad7c2561fd..c5d509b49b 100644 --- a/go/apps/api/routes/v2_identities_create_identity/400_test.go +++ b/go/apps/api/routes/v2_identities_create_identity/400_test.go @@ -19,11 +19,10 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "identity.*.create_identity") route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_create_identity/401_test.go b/go/apps/api/routes/v2_identities_create_identity/401_test.go index 142c94c2d2..1c08398497 100644 --- a/go/apps/api/routes/v2_identities_create_identity/401_test.go +++ b/go/apps/api/routes/v2_identities_create_identity/401_test.go @@ -15,11 +15,10 @@ func TestUnauthorizedAccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_create_identity/403_test.go b/go/apps/api/routes/v2_identities_create_identity/403_test.go index 2f65bc01af..7f40de0960 100644 --- a/go/apps/api/routes/v2_identities_create_identity/403_test.go +++ b/go/apps/api/routes/v2_identities_create_identity/403_test.go @@ -15,11 +15,10 @@ func TestWorkspacePermissions(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_create_identity/409_test.go b/go/apps/api/routes/v2_identities_create_identity/409_test.go index 96be865f20..3bb01a4f71 100644 --- a/go/apps/api/routes/v2_identities_create_identity/409_test.go +++ b/go/apps/api/routes/v2_identities_create_identity/409_test.go @@ -16,11 +16,10 @@ func TestCreateIdentityDuplicate(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_create_identity/handler.go b/go/apps/api/routes/v2_identities_create_identity/handler.go index 7f94569455..ecfa09533e 100644 --- a/go/apps/api/routes/v2_identities_create_identity/handler.go +++ b/go/apps/api/routes/v2_identities_create_identity/handler.go @@ -11,7 +11,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -28,11 +27,10 @@ type Response = openapi.V2IdentitiesCreateIdentityResponseBody // Handler implements zen.Route interface for the v2 identities create identity endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } const ( @@ -53,7 +51,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -67,17 +65,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Identity, - ResourceID: "*", - Action: rbac.CreateIdentity, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Identity, + ResourceID: "*", + Action: rbac.CreateIdentity, + }), + ))) if err != nil { return err } @@ -133,7 +127,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.IdentityCreateEvent, Display: fmt.Sprintf("Created identity %s.", identityID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, ActorType: auditlog.RootKeyActor, @@ -170,7 +164,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitCreateEvent, Display: fmt.Sprintf("Created ratelimit %s.", ratelimitID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorType: auditlog.RootKeyActor, ActorName: "root key", ActorMeta: map[string]any{}, diff --git a/go/apps/api/routes/v2_identities_delete_identity/200_test.go b/go/apps/api/routes/v2_identities_delete_identity/200_test.go index 45f36faece..4887f8ac75 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/200_test.go +++ b/go/apps/api/routes/v2_identities_delete_identity/200_test.go @@ -65,11 +65,10 @@ func TestDeleteIdentitySuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_delete_identity/400_test.go b/go/apps/api/routes/v2_identities_delete_identity/400_test.go index 15dc9c65b8..6b74abe15e 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/400_test.go +++ b/go/apps/api/routes/v2_identities_delete_identity/400_test.go @@ -19,11 +19,10 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "identity.*.delete_identity") route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_delete_identity/401_test.go b/go/apps/api/routes/v2_identities_delete_identity/401_test.go index a4c255a991..5d3bb8ddfd 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/401_test.go +++ b/go/apps/api/routes/v2_identities_delete_identity/401_test.go @@ -17,11 +17,10 @@ func TestDeleteIdentityUnauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_delete_identity/403_test.go b/go/apps/api/routes/v2_identities_delete_identity/403_test.go index 9a2dfdd5f3..9f20fb059a 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/403_test.go +++ b/go/apps/api/routes/v2_identities_delete_identity/403_test.go @@ -19,11 +19,10 @@ func TestDeleteIdentityForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_delete_identity/404_test.go b/go/apps/api/routes/v2_identities_delete_identity/404_test.go index 7a01caadb5..c6eaf5cd7a 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/404_test.go +++ b/go/apps/api/routes/v2_identities_delete_identity/404_test.go @@ -19,11 +19,10 @@ func TestDeleteIdentityNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_delete_identity/handler.go b/go/apps/api/routes/v2_identities_delete_identity/handler.go index 83b81c6b4b..d9e01dcbf3 100644 --- a/go/apps/api/routes/v2_identities_delete_identity/handler.go +++ b/go/apps/api/routes/v2_identities_delete_identity/handler.go @@ -74,7 +74,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -90,11 +89,10 @@ type Response = openapi.V2IdentitiesDeleteIdentityResponseBody // Handler implements zen.Route interface for the v2 identities delete identity endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } // Method returns the HTTP method this route responds to @@ -109,7 +107,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -139,11 +137,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { })) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or(checks...), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or(checks...))) if err != nil { return err } @@ -198,7 +192,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.IdentityDeleteEvent, Display: fmt.Sprintf("Deleted identity %s.", identity.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorType: auditlog.RootKeyActor, ActorName: "root key", ActorMeta: map[string]any{}, @@ -229,7 +223,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitDeleteEvent, Display: fmt.Sprintf("Deleted ratelimit %s.", rl.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorType: auditlog.RootKeyActor, ActorName: "root key", ActorMeta: map[string]any{}, diff --git a/go/apps/api/routes/v2_identities_get_identity/200_test.go b/go/apps/api/routes/v2_identities_get_identity/200_test.go index dec6255959..4af35d3f35 100644 --- a/go/apps/api/routes/v2_identities_get_identity/200_test.go +++ b/go/apps/api/routes/v2_identities_get_identity/200_test.go @@ -16,16 +16,16 @@ import ( "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) @@ -36,12 +36,10 @@ func TestSuccess(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Setup test data + // Setup test data using testutil helper ctx := context.Background() - identityID := uid.New(uid.IdentityPrefix) externalID := "test_user_123" - // Create metadata metaMap := map[string]interface{}{ "name": "Test User", @@ -52,41 +50,26 @@ func TestSuccess(t *testing.T) { metaBytes, err := json.Marshal(metaMap) require.NoError(t, err) - // Insert test identity - err = db.Query.InsertIdentity(ctx, h.DB.RW(), db.InsertIdentityParams{ - ID: identityID, - ExternalID: externalID, + // Create identity with ratelimits using testutil helper + identityID := h.CreateIdentity(seed.CreateIdentityRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - Environment: "default", - CreatedAt: time.Now().UnixMilli(), + ExternalID: externalID, Meta: metaBytes, + Ratelimits: []seed.CreateRatelimitRequest{ + { + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: "api_calls", + Limit: 100, + Duration: 60000, + }, + { + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: "special_feature", + Limit: 10, + Duration: 3600000, + }, + }, }) - require.NoError(t, err) - - // Insert test ratelimits - ratelimitID1 := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertIdentityRatelimit(ctx, h.DB.RW(), db.InsertIdentityRatelimitParams{ - ID: ratelimitID1, - WorkspaceID: h.Resources().UserWorkspace.ID, - IdentityID: sql.NullString{String: identityID, Valid: true}, - Name: "api_calls", - Limit: 100, - Duration: 60000, // 1 minute - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - ratelimitID2 := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertIdentityRatelimit(ctx, h.DB.RW(), db.InsertIdentityRatelimitParams{ - ID: ratelimitID2, - WorkspaceID: h.Resources().UserWorkspace.ID, - IdentityID: sql.NullString{String: identityID, Valid: true}, - Name: "special_feature", - Limit: 10, - Duration: 3600000, // 1 hour - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) // No need to set up permissions since we already gave the key the required permission diff --git a/go/apps/api/routes/v2_identities_get_identity/400_test.go b/go/apps/api/routes/v2_identities_get_identity/400_test.go index faca2afeb7..f3d917055d 100644 --- a/go/apps/api/routes/v2_identities_get_identity/400_test.go +++ b/go/apps/api/routes/v2_identities_get_identity/400_test.go @@ -15,10 +15,9 @@ import ( func TestBadRequests(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_get_identity/401_test.go b/go/apps/api/routes/v2_identities_get_identity/401_test.go index d36ba2e32a..378f3851fb 100644 --- a/go/apps/api/routes/v2_identities_get_identity/401_test.go +++ b/go/apps/api/routes/v2_identities_get_identity/401_test.go @@ -18,10 +18,9 @@ func strPtr(s string) *string { func TestUnauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_get_identity/403_test.go b/go/apps/api/routes/v2_identities_get_identity/403_test.go index 039f527f70..3c421c0fde 100644 --- a/go/apps/api/routes/v2_identities_get_identity/403_test.go +++ b/go/apps/api/routes/v2_identities_get_identity/403_test.go @@ -19,10 +19,9 @@ import ( func TestForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_get_identity/404_test.go b/go/apps/api/routes/v2_identities_get_identity/404_test.go index a1b0c3e829..0f9900708c 100644 --- a/go/apps/api/routes/v2_identities_get_identity/404_test.go +++ b/go/apps/api/routes/v2_identities_get_identity/404_test.go @@ -18,10 +18,9 @@ import ( func TestNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_get_identity/handler.go b/go/apps/api/routes/v2_identities_get_identity/handler.go index 7b44e0a811..620bea05fd 100644 --- a/go/apps/api/routes/v2_identities_get_identity/handler.go +++ b/go/apps/api/routes/v2_identities_get_identity/handler.go @@ -8,7 +8,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -25,10 +24,9 @@ type Response = openapi.V2IdentitiesGetIdentityResponseBody // Handler implements zen.Route interface for the v2 identities get identity endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -43,7 +41,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -117,7 +115,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ratelimits := result.Ratelimits // Check permissions using either wildcard or the specific identity ID - permissionCheck := rbac.Or( + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( rbac.T(rbac.Tuple{ ResourceType: rbac.Identity, ResourceID: "*", @@ -128,9 +126,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ResourceID: identity.ID, Action: rbac.ReadIdentity, }), - ) - - err = h.Permissions.Check(ctx, auth.KeyID, permissionCheck) + ))) if err != nil { return err } diff --git a/go/apps/api/routes/v2_identities_list_identities/200_test.go b/go/apps/api/routes/v2_identities_list_identities/200_test.go index b415f8f934..bdb1be890a 100644 --- a/go/apps/api/routes/v2_identities_list_identities/200_test.go +++ b/go/apps/api/routes/v2_identities_list_identities/200_test.go @@ -20,10 +20,9 @@ import ( func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } // Register the route with the harness diff --git a/go/apps/api/routes/v2_identities_list_identities/400_test.go b/go/apps/api/routes/v2_identities_list_identities/400_test.go index df624d14b4..33f13f456c 100644 --- a/go/apps/api/routes/v2_identities_list_identities/400_test.go +++ b/go/apps/api/routes/v2_identities_list_identities/400_test.go @@ -16,10 +16,9 @@ import ( func TestBadRequests(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } // Register the route with the harness diff --git a/go/apps/api/routes/v2_identities_list_identities/401_test.go b/go/apps/api/routes/v2_identities_list_identities/401_test.go index 7460d2f533..e0c811411e 100644 --- a/go/apps/api/routes/v2_identities_list_identities/401_test.go +++ b/go/apps/api/routes/v2_identities_list_identities/401_test.go @@ -13,10 +13,9 @@ import ( func TestUnauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } // Register the route with the harness diff --git a/go/apps/api/routes/v2_identities_list_identities/403_test.go b/go/apps/api/routes/v2_identities_list_identities/403_test.go index c9bd4d3402..3c9395aa6e 100644 --- a/go/apps/api/routes/v2_identities_list_identities/403_test.go +++ b/go/apps/api/routes/v2_identities_list_identities/403_test.go @@ -18,10 +18,9 @@ import ( func TestForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } // Create a rootKey without any permissions diff --git a/go/apps/api/routes/v2_identities_list_identities/cross_workspace_test.go b/go/apps/api/routes/v2_identities_list_identities/cross_workspace_test.go index 2d0b5d816f..d50787c784 100644 --- a/go/apps/api/routes/v2_identities_list_identities/cross_workspace_test.go +++ b/go/apps/api/routes/v2_identities_list_identities/cross_workspace_test.go @@ -17,10 +17,9 @@ import ( func TestCrossWorkspaceForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_list_identities/handler.go b/go/apps/api/routes/v2_identities_list_identities/handler.go index c5a3fa9dc3..e0ae65cdee 100644 --- a/go/apps/api/routes/v2_identities_list_identities/handler.go +++ b/go/apps/api/routes/v2_identities_list_identities/handler.go @@ -9,7 +9,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" "github.com/unkeyed/unkey/go/pkg/otel/logging" @@ -24,10 +23,9 @@ type Response = openapi.V2IdentitiesListIdentitiesResponseBody // Handler implements zen.Route interface for the v2 identities list identities endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -42,7 +40,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -78,6 +76,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { // Trim the results to the requested limit identities = identities[:limit] } + // Check permissions for all identities before processing for _, id := range identities { // Check permissions @@ -94,11 +93,10 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { }), ) - err = h.Permissions.Check(ctx, auth.KeyID, permissionCheck) + err = auth.Verify(ctx, keys.WithPermissions(permissionCheck)) if err != nil { return err } - } // Process the results and get ratelimits for each identity diff --git a/go/apps/api/routes/v2_identities_update_identity/200_test.go b/go/apps/api/routes/v2_identities_update_identity/200_test.go index 477861494e..95e8a500d6 100644 --- a/go/apps/api/routes/v2_identities_update_identity/200_test.go +++ b/go/apps/api/routes/v2_identities_update_identity/200_test.go @@ -21,11 +21,10 @@ import ( func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_update_identity/400_test.go b/go/apps/api/routes/v2_identities_update_identity/400_test.go index 539a107385..7de40a76d4 100644 --- a/go/apps/api/routes/v2_identities_update_identity/400_test.go +++ b/go/apps/api/routes/v2_identities_update_identity/400_test.go @@ -15,11 +15,10 @@ import ( func TestBadRequests(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_update_identity/401_test.go b/go/apps/api/routes/v2_identities_update_identity/401_test.go index 74b60deed2..2b0695b621 100644 --- a/go/apps/api/routes/v2_identities_update_identity/401_test.go +++ b/go/apps/api/routes/v2_identities_update_identity/401_test.go @@ -15,11 +15,10 @@ import ( func TestUnauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_update_identity/403_test.go b/go/apps/api/routes/v2_identities_update_identity/403_test.go index 0010567d98..b51f4b86de 100644 --- a/go/apps/api/routes/v2_identities_update_identity/403_test.go +++ b/go/apps/api/routes/v2_identities_update_identity/403_test.go @@ -18,11 +18,10 @@ import ( func TestForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_update_identity/404_test.go b/go/apps/api/routes/v2_identities_update_identity/404_test.go index cacaa8608b..db7ef7b5f6 100644 --- a/go/apps/api/routes/v2_identities_update_identity/404_test.go +++ b/go/apps/api/routes/v2_identities_update_identity/404_test.go @@ -16,11 +16,10 @@ import ( func TestNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_identities_update_identity/handler.go b/go/apps/api/routes/v2_identities_update_identity/handler.go index f5d76f8d52..1a40d53317 100644 --- a/go/apps/api/routes/v2_identities_update_identity/handler.go +++ b/go/apps/api/routes/v2_identities_update_identity/handler.go @@ -11,7 +11,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -28,11 +27,10 @@ type Response = openapi.V2IdentitiesUpdateIdentityResponseBody // Handler implements zen.Route interface for the v2 identities update identity endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } const ( @@ -52,7 +50,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -85,22 +83,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { identityIdForPermissions = *req.IdentityId } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Identity, - ResourceID: "*", - Action: rbac.UpdateIdentity, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Identity, - ResourceID: identityIdForPermissions, - Action: rbac.UpdateIdentity, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Identity, + ResourceID: "*", + Action: rbac.UpdateIdentity, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Identity, + ResourceID: identityIdForPermissions, + Action: rbac.UpdateIdentity, + }), + ))) if err != nil { return err } @@ -195,7 +189,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.IdentityUpdateEvent, Display: fmt.Sprintf("Updated identity %s", identity.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorType: auditlog.RootKeyActor, ActorMeta: map[string]any{}, @@ -261,7 +255,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitDeleteEvent, Display: fmt.Sprintf("Deleted ratelimit %s", existingRL.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorType: auditlog.RootKeyActor, ActorMeta: map[string]any{}, @@ -316,7 +310,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitUpdateEvent, Display: fmt.Sprintf("Updated ratelimit %s", existingRL.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorType: auditlog.RootKeyActor, ActorMeta: map[string]any{}, @@ -361,7 +355,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitCreateEvent, Display: fmt.Sprintf("Created ratelimit %s", ratelimitID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorType: auditlog.RootKeyActor, ActorMeta: map[string]any{}, diff --git a/go/apps/api/routes/v2_keys_add_permissions/200_test.go b/go/apps/api/routes/v2_keys_add_permissions/200_test.go index 825603adab..6367fb921e 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/200_test.go +++ b/go/apps/api/routes/v2_keys_add_permissions/200_test.go @@ -2,18 +2,15 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestSuccess(t *testing.T) { @@ -21,11 +18,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,52 +38,32 @@ func TestSuccess(t *testing.T) { } t.Run("add single permission by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) - - // Create a permission - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.single.id", - Slug: "documents.read.single.id", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + keyID := keyResponse.KeyID + + // Create a permission using testutil helper + permissionDescription := "Read documents permission" + permissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.single.id", + Slug: "documents.read.single.id", + Description: &permissionDescription, }) - require.NoError(t, err) // Verify key has no permissions initially currentPermissions, err := db.Query.ListDirectPermissionsByKeyID(ctx, h.DB.RO(), keyID) @@ -141,53 +118,33 @@ func TestSuccess(t *testing.T) { }) t.Run("add single permission by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID - // Create a permission - permissionID := uid.New(uid.TestPrefix) + // Create a permission using testutil helper permissionSlug := "documents.write.single.name" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: permissionSlug, - Slug: permissionSlug, - Description: sql.NullString{Valid: true, String: "Write documents permission"}, + permissionDescription := "Write documents permission" + permissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: permissionSlug, + Slug: permissionSlug, + Description: &permissionDescription, }) - require.NoError(t, err) // Verify key has no permissions initially currentPermissions, err := db.Query.ListDirectPermissionsByKeyID(ctx, h.DB.RO(), keyID) @@ -227,64 +184,42 @@ func TestSuccess(t *testing.T) { }) t.Run("add multiple permissions", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID - // Create permissions - permission1ID := uid.New(uid.TestPrefix) + // Create permissions using testutil helper permission1Name := "documents.read.multiple" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission1ID, - WorkspaceID: workspace.ID, - Name: permission1Name, - Slug: permission1Name, - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + permissionDescription1 := "Read documents permission" + permission1ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: permission1Name, + Slug: permission1Name, + Description: &permissionDescription1, }) - require.NoError(t, err) - permission2ID := uid.New(uid.TestPrefix) permission2Slug := "documents.write.multiple" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission2ID, - WorkspaceID: workspace.ID, - Name: permission2Slug, - Slug: permission2Slug, - Description: sql.NullString{Valid: true, String: "Write documents permission"}, + permissionDescription2 := "Write documents permission" + permission2ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: permission2Slug, + Slug: permission2Slug, + Description: &permissionDescription2, }) - require.NoError(t, err) req := handler.Request{ KeyId: keyID, @@ -323,52 +258,32 @@ func TestSuccess(t *testing.T) { }) t.Run("idempotent operation - adding same permission twice", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) - - // Create a permission - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.idempotent", - Slug: "documents.read.idempotent", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + keyID := keyResponse.KeyID + + // Create a permission using testutil helper + permissionDescription := "Read documents permission" + permissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.idempotent", + Slug: "documents.read.idempotent", + Description: &permissionDescription, }) - require.NoError(t, err) req := handler.Request{ KeyId: keyID, @@ -414,71 +329,41 @@ func TestSuccess(t *testing.T) { }) t.Run("add permissions to key that already has permissions", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - // Create permissions - existingPermissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: existingPermissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.existing", - Slug: "documents.read.existing", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - newPermissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: newPermissionID, - WorkspaceID: workspace.ID, - Name: "documents.write.existing", - Slug: "documents.write.existing", - Description: sql.NullString{Valid: true, String: "Write documents permission"}, + // Create permissions using testutil helper + existingPermissionDescription := "Read documents permission" + newPermissionDescription := "Write documents permission" + newPermissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.write.existing", + Slug: "documents.write.existing", + Description: &newPermissionDescription, }) - require.NoError(t, err) - // Add existing permission to key first - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: keyID, - PermissionID: existingPermissionID, - WorkspaceID: workspace.ID, - CreatedAt: time.Now().UnixMilli(), + // Create a test key with existing permission using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: "documents.read.existing", + Slug: "documents.read.existing", + Description: &existingPermissionDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -508,7 +393,7 @@ func TestSuccess(t *testing.T) { for _, p := range res.Body.Data { permissionIDs[p.Id] = true } - require.True(t, permissionIDs[existingPermissionID]) + require.True(t, permissionIDs[keyResponse.PermissionIds[0]]) require.True(t, permissionIDs[newPermissionID]) // Verify permissions in database diff --git a/go/apps/api/routes/v2_keys_add_permissions/400_test.go b/go/apps/api/routes/v2_keys_add_permissions/400_test.go index 3e02cc439e..677752aa67 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/400_test.go +++ b/go/apps/api/routes/v2_keys_add_permissions/400_test.go @@ -22,11 +22,11 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -69,10 +69,6 @@ func TestValidationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_add_permissions/401_test.go b/go/apps/api/routes/v2_keys_add_permissions/401_test.go index b6c5d0da63..5507eb6991 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/401_test.go +++ b/go/apps/api/routes/v2_keys_add_permissions/401_test.go @@ -22,11 +22,11 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -62,10 +62,6 @@ func TestAuthenticationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_add_permissions/403_test.go b/go/apps/api/routes/v2_keys_add_permissions/403_test.go index 2e69868741..93bdc96aee 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/403_test.go +++ b/go/apps/api/routes/v2_keys_add_permissions/403_test.go @@ -22,11 +22,11 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -62,10 +62,6 @@ func TestAuthorizationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) @@ -172,10 +168,6 @@ func TestAuthorizationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_add_permissions/404_test.go b/go/apps/api/routes/v2_keys_add_permissions/404_test.go index 9b12ec9002..d7d8f11f91 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/404_test.go +++ b/go/apps/api/routes/v2_keys_add_permissions/404_test.go @@ -12,8 +12,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +22,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -80,41 +80,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("permission not found by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use a non-existent permission ID nonExistentPermissionID := uid.New(uid.TestPrefix) @@ -143,41 +124,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("permission not found by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentPermissionSlug := "nonexistent.permission.name" @@ -226,41 +188,22 @@ func TestNotFoundErrors(t *testing.T) { }) require.NoError(t, err) - // Create a test keyring in our workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in our workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key in our workspace - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -296,41 +239,22 @@ func TestNotFoundErrors(t *testing.T) { }) require.NoError(t, err) - // Create a test keyring in the other workspace - otherKeyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: otherKeyAuthID, - WorkspaceID: otherWorkspaceID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in the other workspace using testutil helpers + otherDefaultPrefix := "test" + otherDefaultBytes := int32(16) + otherApi := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: otherWorkspaceID, + DefaultPrefix: &otherDefaultPrefix, + DefaultBytes: &otherDefaultBytes, }) - require.NoError(t, err) - // Create a test key in the other workspace - otherKeyID := uid.New(uid.KeyPrefix) - otherKeyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: otherKeyID, - KeyringID: otherKeyAuthID, - Hash: hash.Sha256(otherKeyString), - Start: otherKeyString[:4], - WorkspaceID: otherWorkspaceID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Other Workspace Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + otherKeyName := "Other Workspace Key" + otherKeyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: otherWorkspaceID, + KeyAuthID: otherApi.KeyAuthID.String, + Name: &otherKeyName, }) - require.NoError(t, err) + otherKeyID := otherKeyResponse.KeyID // Create a permission in our workspace permissionID := uid.New(uid.TestPrefix) diff --git a/go/apps/api/routes/v2_keys_add_permissions/handler.go b/go/apps/api/routes/v2_keys_add_permissions/handler.go index f7967dea16..736c00a439 100644 --- a/go/apps/api/routes/v2_keys_add_permissions/handler.go +++ b/go/apps/api/routes/v2_keys_add_permissions/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -24,11 +24,11 @@ type Request = openapi.V2KeysAddPermissionsRequestBody type Response = openapi.V2KeysAddPermissionsResponse type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -43,9 +43,8 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -57,17 +56,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -198,7 +193,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectPermissionKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Added permission %s to key %s", permission.Name, req.KeyId), @@ -236,6 +231,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if err != nil { return err } + + h.KeyCache.Remove(ctx, key.Hash) } // 9. Get final state of direct permissions and build response diff --git a/go/apps/api/routes/v2_keys_add_roles/200_test.go b/go/apps/api/routes/v2_keys_add_roles/200_test.go index 1ee77a4251..776ebd887a 100644 --- a/go/apps/api/routes/v2_keys_add_roles/200_test.go +++ b/go/apps/api/routes/v2_keys_add_roles/200_test.go @@ -2,7 +2,6 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" @@ -11,9 +10,9 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestSuccess(t *testing.T) { @@ -21,11 +20,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,64 +40,35 @@ func TestSuccess(t *testing.T) { } t.Run("add single role by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + EncryptedKeys: false, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: workspace.ID, }) - require.NoError(t, err) - // Create a role - roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, + roleId := h.CreateRole(seed.CreateRoleRequest{ + WorkspaceID: workspace.ID, Name: "admin_single_id", - Description: sql.NullString{Valid: true, String: "Admin role"}, + Description: ptr.P("Admin Role"), }) - require.NoError(t, err) // Verify key has no roles initially - currentRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + currentRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Empty(t, currentRoles) req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` }{ - {Id: &roleID}, + {Id: &roleId}, }, } @@ -113,17 +83,17 @@ func TestSuccess(t *testing.T) { require.NotNil(t, res.Body) require.NotNil(t, res.Body.Data) require.Len(t, res.Body.Data, 1) - require.Equal(t, roleID, res.Body.Data[0].Id) + require.Equal(t, roleId, res.Body.Data[0].Id) require.Equal(t, "admin_single_id", res.Body.Data[0].Name) // Verify role was added to key - finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Len(t, finalRoles, 1) - require.Equal(t, roleID, finalRoles[0].ID) + require.Equal(t, roleId, finalRoles[0].ID) // Verify audit log was created - auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), keyID) + auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.NotEmpty(t, auditLogs) @@ -139,55 +109,26 @@ func TestSuccess(t *testing.T) { }) t.Run("add single role by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + EncryptedKeys: false, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - // Create a role - roleID := uid.New(uid.TestPrefix) roleName := "editor_single_name" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, WorkspaceID: workspace.ID, - Name: roleName, - Description: sql.NullString{Valid: true, String: "Editor role"}, + Roles: []seed.CreateRoleRequest{ + { + WorkspaceID: workspace.ID, + Name: "editor_single_name", + Description: ptr.P(roleName), + }, + }, }) - require.NoError(t, err) req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` @@ -207,91 +148,52 @@ func TestSuccess(t *testing.T) { require.NotNil(t, res.Body) require.NotNil(t, res.Body.Data) require.Len(t, res.Body.Data, 1) - require.Equal(t, roleID, res.Body.Data[0].Id) + require.Equal(t, key.RolesIds[0], res.Body.Data[0].Id) require.Equal(t, roleName, res.Body.Data[0].Name) // Verify role was added to key - finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Len(t, finalRoles, 1) - require.Equal(t, roleID, finalRoles[0].ID) + require.Equal(t, key.RolesIds[0], finalRoles[0].ID) }) t.Run("add multiple roles mixed references", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + EncryptedKeys: false, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: workspace.ID, }) - require.NoError(t, err) - // Create multiple roles - adminRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: adminRoleID, + adminRole := h.CreateRole(seed.CreateRoleRequest{ WorkspaceID: workspace.ID, Name: "admin_multi", - Description: sql.NullString{Valid: true, String: "Admin role"}, }) - require.NoError(t, err) - editorRoleID := uid.New(uid.TestPrefix) - editorRoleName := "editor_multi" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: editorRoleID, + viewerMultiRole := h.CreateRole(seed.CreateRoleRequest{ WorkspaceID: workspace.ID, - Name: editorRoleName, - Description: sql.NullString{Valid: true, String: "Editor role"}, + Name: "viewer_multi", }) - require.NoError(t, err) - viewerRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: viewerRoleID, + editorRoleName := "editor_multi" + h.CreateRole(seed.CreateRoleRequest{ WorkspaceID: workspace.ID, - Name: "viewer_multi", - Description: sql.NullString{Valid: true, String: "Viewer role"}, + Name: editorRoleName, }) - require.NoError(t, err) req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` }{ - {Id: &adminRoleID}, // By ID + {Id: &adminRole}, // By ID {Name: &editorRoleName}, // By name - {Id: &viewerRoleID}, // By ID + {Id: &viewerMultiRole}, // By ID }, } @@ -312,12 +214,12 @@ func TestSuccess(t *testing.T) { require.Equal(t, []string{"admin_multi", "editor_multi", "viewer_multi"}, roleNames) // Verify roles were added to key - finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Len(t, finalRoles, 3) // Verify audit logs were created (one for each role) - auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), keyID) + auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.NotEmpty(t, auditLogs) @@ -332,65 +234,31 @@ func TestSuccess(t *testing.T) { }) t.Run("idempotent behavior - add existing roles", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + EncryptedKeys: false, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: workspace.ID, }) - require.NoError(t, err) - - // Create roles - adminRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: adminRoleID, + adminId := h.CreateRole(seed.CreateRoleRequest{ WorkspaceID: workspace.ID, Name: "admin_idempotent", - Description: sql.NullString{Valid: true, String: "Admin role"}, + Description: ptr.P("admin_idempotent"), }) - require.NoError(t, err) - editorRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: editorRoleID, + editorId := h.CreateRole(seed.CreateRoleRequest{ WorkspaceID: workspace.ID, Name: "editor_idempotent", - Description: sql.NullString{Valid: true, String: "Editor role"}, + Description: ptr.P("editor_idempotent"), }) - require.NoError(t, err) // First, add admin role to the key - err = db.Query.InsertKeyRole(ctx, h.DB.RW(), db.InsertKeyRoleParams{ - KeyID: keyID, - RoleID: adminRoleID, + err := db.Query.InsertKeyRole(ctx, h.DB.RW(), db.InsertKeyRoleParams{ + KeyID: key.KeyID, + RoleID: adminId, WorkspaceID: workspace.ID, CreatedAtM: time.Now().UnixMilli(), }) @@ -398,13 +266,13 @@ func TestSuccess(t *testing.T) { // Now try to add both admin (existing) and editor (new) roles req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` }{ - {Id: &adminRoleID}, // Already exists - {Id: &editorRoleID}, // New role + {Id: &adminId}, // Already exists + {Id: &editorId}, // New role }, } @@ -425,12 +293,12 @@ func TestSuccess(t *testing.T) { require.Equal(t, []string{"admin_idempotent", "editor_idempotent"}, roleNames) // Verify roles in database - finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Len(t, finalRoles, 2) // Verify audit logs - should only have one new log for the editor role - auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), keyID) + auditLogs, err := db.Query.FindAuditLogTargetByID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.NotEmpty(t, auditLogs) @@ -438,7 +306,7 @@ func TestSuccess(t *testing.T) { editorConnectEvents := 0 for _, log := range auditLogs { if log.AuditLog.Event == "authorization.connect_role_and_key" && - log.AuditLog.Display == fmt.Sprintf("Added role editor_idempotent to key %s", keyID) { + log.AuditLog.Display == fmt.Sprintf("Added role editor_idempotent to key %s", key.KeyID) { editorConnectEvents++ } } @@ -446,73 +314,31 @@ func TestSuccess(t *testing.T) { }) t.Run("role reference with both ID and name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + EncryptedKeys: false, }) - require.NoError(t, err) - - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - // Create roles - role1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: role1ID, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, WorkspaceID: workspace.ID, - Name: "admin_both_ref", - Description: sql.NullString{Valid: true, String: "Admin role"}, }) - require.NoError(t, err) - role2ID := uid.New(uid.TestPrefix) - role2Name := "editor_both_ref" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: role2ID, - WorkspaceID: workspace.ID, - Name: role2Name, - Description: sql.NullString{Valid: true, String: "Editor role"}, - }) - require.NoError(t, err) + adminID := h.CreateRole(seed.CreateRoleRequest{WorkspaceID: workspace.ID, Name: "admin_both_ref"}) + editorRoleName := "editor_both_ref" + h.CreateRole(seed.CreateRoleRequest{WorkspaceID: workspace.ID, Name: editorRoleName}) // Request with role reference having both ID and name // ID should take precedence req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` }{ { - Id: &role1ID, - Name: &role2Name, // This should be ignored, ID takes precedence + Id: &adminID, + Name: &editorRoleName, // This should be ignored, ID takes precedence }, }, } @@ -528,13 +354,13 @@ func TestSuccess(t *testing.T) { require.NotNil(t, res.Body) require.NotNil(t, res.Body.Data) require.Len(t, res.Body.Data, 1) - require.Equal(t, role1ID, res.Body.Data[0].Id) + require.Equal(t, adminID, res.Body.Data[0].Id) require.Equal(t, "admin_both_ref", res.Body.Data[0].Name) // Should be role1, not role2 // Verify correct role was added - finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) + finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), key.KeyID) require.NoError(t, err) require.Len(t, finalRoles, 1) - require.Equal(t, role1ID, finalRoles[0].ID) + require.Equal(t, adminID, finalRoles[0].ID) }) } diff --git a/go/apps/api/routes/v2_keys_add_roles/400_test.go b/go/apps/api/routes/v2_keys_add_roles/400_test.go index 7a611e694b..b498325444 100644 --- a/go/apps/api/routes/v2_keys_add_roles/400_test.go +++ b/go/apps/api/routes/v2_keys_add_roles/400_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_roles" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestValidationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,40 +35,22 @@ func TestValidationErrors(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create a test key for valid requests - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create a test API and key for valid requests using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - validKeyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: validKeyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + validKeyID := keyResponse.KeyID // Test case for missing keyId t.Run("missing keyId", func(t *testing.T) { diff --git a/go/apps/api/routes/v2_keys_add_roles/401_test.go b/go/apps/api/routes/v2_keys_add_roles/401_test.go index 05bd207253..372fccc6a5 100644 --- a/go/apps/api/routes/v2_keys_add_roles/401_test.go +++ b/go/apps/api/routes/v2_keys_add_roles/401_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_roles" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestAuthenticationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -34,41 +28,22 @@ func TestAuthenticationErrors(t *testing.T) { // Create a workspace and valid key for the request workspace := h.Resources().UserWorkspace - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create a test API and key using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a valid request req := handler.Request{ diff --git a/go/apps/api/routes/v2_keys_add_roles/403_test.go b/go/apps/api/routes/v2_keys_add_roles/403_test.go index 6998ddea13..2dd8a7f76c 100644 --- a/go/apps/api/routes/v2_keys_add_roles/403_test.go +++ b/go/apps/api/routes/v2_keys_add_roles/403_test.go @@ -14,6 +14,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +23,11 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -37,41 +38,22 @@ func TestAuthorizationErrors(t *testing.T) { workspace := h.Resources().UserWorkspace rootKey := h.CreateRootKey(workspace.ID, "api.*.read_key") // Only read permission - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -139,10 +121,6 @@ func TestAuthorizationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) @@ -180,41 +158,22 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with only read permissions rootKey := h.CreateRootKey(workspace.ID, "api.*.read_key", "rbac.*.read_role") - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -250,41 +209,22 @@ func TestAuthorizationErrors(t *testing.T) { // Use invalid root key to simulate expired key rootKey := "expired_root_key_12345" - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, diff --git a/go/apps/api/routes/v2_keys_add_roles/404_test.go b/go/apps/api/routes/v2_keys_add_roles/404_test.go index cffb0f02ff..77992c9a4d 100644 --- a/go/apps/api/routes/v2_keys_add_roles/404_test.go +++ b/go/apps/api/routes/v2_keys_add_roles/404_test.go @@ -12,8 +12,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_add_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +22,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -71,41 +71,22 @@ func TestNotFoundErrors(t *testing.T) { // Create a second workspace with a key workspace2 := h.CreateWorkspace() - // Create a test keyring in workspace2 - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace2.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create a test API and key in workspace2 using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace2.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key in workspace2 - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace2.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace2.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, // Key from workspace2, but accessed with workspace1 root key @@ -132,41 +113,22 @@ func TestNotFoundErrors(t *testing.T) { // Test case for role not found by ID t.Run("role not found by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentRoleId := "role_nonexistent123456789" req := handler.Request{ @@ -195,41 +157,22 @@ func TestNotFoundErrors(t *testing.T) { // Test case for role not found by name t.Run("role not found by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentRoleName := "nonexistent_role" req := handler.Request{ @@ -269,41 +212,22 @@ func TestNotFoundErrors(t *testing.T) { }) require.NoError(t, err) - // Create a test keyring in workspace1 - keyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in workspace1 using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key in workspace1 - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -342,41 +266,22 @@ func TestNotFoundErrors(t *testing.T) { }) require.NoError(t, err) - // Create a test keyring in workspace1 - keyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in workspace1 using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key in workspace1 - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -403,45 +308,26 @@ func TestNotFoundErrors(t *testing.T) { // Test case for multiple roles with one not found t.Run("multiple roles one not found", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create one valid role validRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: validRoleID, WorkspaceID: workspace.ID, Name: "admin_multiple_roles", @@ -479,39 +365,26 @@ func TestNotFoundErrors(t *testing.T) { // Test case for deleted key t.Run("deleted key", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Deleted: true, + }) + + err := db.Query.SoftDeleteKeyByID(ctx, h.DB.RW(), db.SoftDeleteKeyByIDParams{ + Now: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + ID: keyResponse.KeyID, }) require.NoError(t, err) @@ -525,11 +398,8 @@ func TestNotFoundErrors(t *testing.T) { }) require.NoError(t, err) - // Use non-existent key ID to simulate deleted key - deletedKeyID := "key_deleted123456789" - req := handler.Request{ - KeyId: deletedKeyID, + KeyId: keyResponse.KeyID, Roles: []struct { Id *string `json:"id,omitempty"` Name *string `json:"name,omitempty"` @@ -553,45 +423,26 @@ func TestNotFoundErrors(t *testing.T) { // Test case for deleted role t.Run("deleted role", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "admin_deleted_role", diff --git a/go/apps/api/routes/v2_keys_add_roles/handler.go b/go/apps/api/routes/v2_keys_add_roles/handler.go index 2debeab857..e5f8efa62e 100644 --- a/go/apps/api/routes/v2_keys_add_roles/handler.go +++ b/go/apps/api/routes/v2_keys_add_roles/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -24,11 +24,11 @@ type Request = openapi.V2KeysAddRolesRequestBody type Response = openapi.V2KeysAddRolesResponse type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -46,7 +46,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.addRoles") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -58,17 +58,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -96,6 +92,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } + if key.DeletedAtM.Valid { + return fault.New("key not found", + fault.Code(codes.Data.Key.NotFound.URN()), + fault.Internal("key is deleted"), fault.Public("The specified key was not found."), + ) + } + // 5. Get current roles for the key currentRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), req.KeyId) if err != nil { @@ -198,7 +201,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectRoleKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Added role %s to key %s", role.Name, req.KeyId), @@ -236,6 +239,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if err != nil { return err } + + h.KeyCache.Remove(ctx, key.Hash) } // 9. Get final state of roles and build response diff --git a/go/apps/api/routes/v2_keys_create_key/200_test.go b/go/apps/api/routes/v2_keys_create_key/200_test.go index a818c683b3..05db5d1351 100644 --- a/go/apps/api/routes/v2_keys_create_key/200_test.go +++ b/go/apps/api/routes/v2_keys_create_key/200_test.go @@ -2,18 +2,16 @@ package handler_test import ( "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_create_key" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func Test_CreateKey_Success(t *testing.T) { @@ -23,37 +21,19 @@ func Test_CreateKey_Success(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) - // Create API manually - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", + // Create API using testutil helper + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), }) - require.NoError(t, err) rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.create_key") @@ -64,7 +44,7 @@ func Test_CreateKey_Success(t *testing.T) { // Test basic key creation req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, } res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) @@ -92,37 +72,19 @@ func Test_CreateKey_WithOptionalFields(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) - // Create API manually - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", + // Create API using testutil helper + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), }) - require.NoError(t, err) rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.create_key") @@ -140,7 +102,7 @@ func Test_CreateKey_WithOptionalFields(t *testing.T) { enabled := true req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Name: &name, Prefix: &prefix, ExternalId: &externalID, @@ -173,43 +135,20 @@ func TestCreateKeyWithEncryption(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) - // Create API manually - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, + // Create API with encrypted keys using testutil helper + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, + EncryptedKeys: true, }) - require.NoError(t, err) - - err = db.Query.UpdateKeyringKeyEncryption(ctx, h.DB.RW(), db.UpdateKeyringKeyEncryptionParams{ - ID: keyAuthID, - StoreEncryptedKeys: true, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", - WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.create_key", "api.*.encrypt_key") @@ -222,7 +161,7 @@ func TestCreateKeyWithEncryption(t *testing.T) { name := "Test Key" req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Name: &name, ExternalId: ptr.P("user_123"), Enabled: ptr.P(true), diff --git a/go/apps/api/routes/v2_keys_create_key/400_test.go b/go/apps/api/routes/v2_keys_create_key/400_test.go index 339c0856d5..d4ba360320 100644 --- a/go/apps/api/routes/v2_keys_create_key/400_test.go +++ b/go/apps/api/routes/v2_keys_create_key/400_test.go @@ -1,59 +1,37 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "strings" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_create_key" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func Test_CreateKey_BadRequest(t *testing.T) { h := testutil.NewHarness(t) - ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) - // Create API for valid tests - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", + // Create API using testutil helper + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), }) - require.NoError(t, err) rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.create_key") @@ -95,7 +73,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("byteLength too small", func(t *testing.T) { invalidByteLength := 10 // minimum is 16 req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, ByteLength: &invalidByteLength, } @@ -107,7 +85,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("byteLength too large", func(t *testing.T) { invalidByteLength := 300 // maximum is 255 req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, ByteLength: &invalidByteLength, } @@ -119,7 +97,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("prefix too long", func(t *testing.T) { invalidPrefix := "this_prefix_is_way_too_long_for_the_api" // max is 16 req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Prefix: &invalidPrefix, } @@ -131,7 +109,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("negative expires timestamp", func(t *testing.T) { invalidExpires := int64(-1) req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Expires: &invalidExpires, } @@ -143,7 +121,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("empty permission in list", func(t *testing.T) { emptyPermissions := []string{""} req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Permissions: &emptyPermissions, } @@ -155,7 +133,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("empty role in list", func(t *testing.T) { emptyRoles := []string{""} req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Roles: &emptyRoles, } @@ -169,7 +147,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { longPermission := strings.Repeat("a", 513) longPermissions := []string{longPermission} req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Permissions: &longPermissions, } @@ -181,7 +159,7 @@ func Test_CreateKey_BadRequest(t *testing.T) { t.Run("role too long", func(t *testing.T) { // Create a role string that's longer than 512 characters req := handler.Request{ - ApiId: apiID, + ApiId: api.ID, Roles: ptr.P([]string{strings.Repeat("a", 513)}), } diff --git a/go/apps/api/routes/v2_keys_create_key/401_test.go b/go/apps/api/routes/v2_keys_create_key/401_test.go index 5228b72c93..0ce5de1a09 100644 --- a/go/apps/api/routes/v2_keys_create_key/401_test.go +++ b/go/apps/api/routes/v2_keys_create_key/401_test.go @@ -17,11 +17,11 @@ func Test_CreateKey_Unauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_create_key/403_test.go b/go/apps/api/routes/v2_keys_create_key/403_test.go index e92f3b3b6e..a38f41db69 100644 --- a/go/apps/api/routes/v2_keys_create_key/403_test.go +++ b/go/apps/api/routes/v2_keys_create_key/403_test.go @@ -23,11 +23,11 @@ func Test_CreateKey_Forbidden(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_create_key/404_test.go b/go/apps/api/routes/v2_keys_create_key/404_test.go index b75e51c3b6..1672a55671 100644 --- a/go/apps/api/routes/v2_keys_create_key/404_test.go +++ b/go/apps/api/routes/v2_keys_create_key/404_test.go @@ -17,11 +17,11 @@ func Test_CreateKey_NotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_create_key/412_test.go b/go/apps/api/routes/v2_keys_create_key/412_test.go index bcc7827c6a..a0266f8dcd 100644 --- a/go/apps/api/routes/v2_keys_create_key/412_test.go +++ b/go/apps/api/routes/v2_keys_create_key/412_test.go @@ -22,11 +22,11 @@ func TestPreconditionError(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_create_key/handler.go b/go/apps/api/routes/v2_keys_create_key/handler.go index bb212b7cf1..b951061215 100644 --- a/go/apps/api/routes/v2_keys_create_key/handler.go +++ b/go/apps/api/routes/v2_keys_create_key/handler.go @@ -13,7 +13,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" @@ -31,12 +30,11 @@ type Request = openapi.V2KeysCreateKeyRequestBody type Response = openapi.V2KeysCreateKeyResponseBody type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService - Vault *vault.Service + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + Vault *vault.Service } // Method returns the HTTP method this route responds to @@ -54,7 +52,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.createKey") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -66,22 +64,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: req.ApiId, - Action: rbac.CreateKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.CreateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: req.ApiId, + Action: rbac.CreateKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.CreateKey, + }), + ))) if err != nil { return err } @@ -138,22 +132,25 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { encrypt := ptr.SafeDeref(req.Recoverable, false) var encryption *vaultv1.EncryptResponse if encrypt { - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.EncryptKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: api.ID, - Action: rbac.EncryptKey, - }), - ), - ) + if h.Vault == nil { + return fault.New("vault missing", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Public("Vault hasn't been set up."), + ) + } + + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.EncryptKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: api.ID, + Action: rbac.EncryptKey, + }), + ))) if err != nil { return err } @@ -247,10 +244,6 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { IdentityID: sql.NullString{String: "", Valid: false}, Meta: sql.NullString{String: "", Valid: false}, Expires: sql.NullTime{Time: time.Time{}, Valid: false}, - RatelimitAsync: sql.NullBool{Bool: false, Valid: false}, - RatelimitLimit: sql.NullInt32{Int32: 0, Valid: false}, - RatelimitDuration: sql.NullInt64{Int64: 0, Valid: false}, - Environment: sql.NullString{String: "", Valid: false}, } // Set optional fields @@ -380,7 +373,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectPermissionKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Added permission %s to key %s", permission.Name, keyID), @@ -434,7 +427,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectRoleKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Connected role %s to key %s", role.Name, keyID), @@ -479,7 +472,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.KeyCreateEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Created key %s", keyID), diff --git a/go/apps/api/routes/v2_keys_delete_key/200_test.go b/go/apps/api/routes/v2_keys_delete_key/200_test.go index fdf61e2c97..7f44928f02 100644 --- a/go/apps/api/routes/v2_keys_delete_key/200_test.go +++ b/go/apps/api/routes/v2_keys_delete_key/200_test.go @@ -11,11 +11,10 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_delete_key" vaultv1 "github.com/unkeyed/unkey/go/gen/proto/vault/v1" - "github.com/unkeyed/unkey/go/internal/services/keys" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestKeyDeleteSuccess(t *testing.T) { @@ -23,11 +22,11 @@ func TestKeyDeleteSuccess(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -35,78 +34,56 @@ func TestKeyDeleteSuccess(t *testing.T) { // Create a workspace and user workspace := h.Resources().UserWorkspace - // Create a keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create a test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", + // Create a test API using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - softDeleteKeyID := uid.New(uid.KeyPrefix) - softDeleteKey, err := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - require.NoError(t, err) - - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: softDeleteKeyID, - KeyringID: keyAuthID, - Hash: softDeleteKey.Hash, - Start: softDeleteKey.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false, String: ""}, - RemainingRequests: sql.NullInt32{Int32: 0, Valid: false}, + Name: &apiName, }) - require.NoError(t, err) - hardDeleteKeyID := uid.New(uid.KeyPrefix) - hardDeleteKey, err := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + // Create a test key for soft delete using testutil helper + softDeleteKeyName := "test-key" + softDeleteKeyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &softDeleteKeyName, }) - require.NoError(t, err) + softDeleteKeyID := softDeleteKeyResponse.KeyID - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: hardDeleteKeyID, - KeyringID: keyAuthID, - Hash: hardDeleteKey.Hash, - Start: hardDeleteKey.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false, String: ""}, - RemainingRequests: sql.NullInt32{Int32: 0, Valid: false}, + // Create a test key for hard delete with all relationships using testutil helper + hardDeleteKeyName := "test-key" + hardDeleteKeyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &hardDeleteKeyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: "read_data", + Slug: "read_data", + }, + }, + Roles: []seed.CreateRoleRequest{ + { + WorkspaceID: workspace.ID, + Name: "data_admin", + }, + }, + Ratelimits: []seed.CreateRatelimitRequest{ + { + WorkspaceID: workspace.ID, + Name: "api_calls", + Limit: 100, + Duration: 60000, + }, + }, }) - require.NoError(t, err) + hardDeleteKeyID := hardDeleteKeyResponse.KeyID + // Add encryption to the hard delete key encryption, err := h.Vault.Encrypt(ctx, &vaultv1.EncryptRequest{ Keyring: workspace.ID, - Data: hardDeleteKey.Key, + Data: hardDeleteKeyResponse.Key, }) require.NoError(t, err) @@ -119,55 +96,6 @@ func TestKeyDeleteSuccess(t *testing.T) { }) require.NoError(t, err) - // Create permissions - perm1ID := uid.New(uid.PermissionPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: perm1ID, - WorkspaceID: workspace.ID, - Name: "read_data", - Slug: "read_data", - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Assign permissions to key - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: hardDeleteKeyID, - PermissionID: perm1ID, - WorkspaceID: workspace.ID, - }) - require.NoError(t, err) - - roleID := uid.New(uid.RolePrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, - WorkspaceID: workspace.ID, - Name: "data_admin", - }) - require.NoError(t, err) - - // Assign role to key - err = db.Query.InsertKeyRole(ctx, h.DB.RW(), db.InsertKeyRoleParams{ - KeyID: hardDeleteKeyID, - RoleID: roleID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Create ratelimits for the key - rl1ID := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertKeyRatelimit(ctx, h.DB.RW(), db.InsertKeyRatelimitParams{ - ID: rl1ID, - WorkspaceID: workspace.ID, - KeyID: sql.NullString{Valid: true, String: hardDeleteKeyID}, - Name: "api_calls", - Limit: 100, - Duration: 60000, // 1 minute - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) - // Create a root key with appropriate permissions rootKey := h.CreateRootKey(workspace.ID, "api.*.delete_key") headers := http.Header{ diff --git a/go/apps/api/routes/v2_keys_delete_key/400_test.go b/go/apps/api/routes/v2_keys_delete_key/400_test.go index 298a8863a5..147f38cfc9 100644 --- a/go/apps/api/routes/v2_keys_delete_key/400_test.go +++ b/go/apps/api/routes/v2_keys_delete_key/400_test.go @@ -15,11 +15,11 @@ func TestKeyDeleteBadRequest(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_delete_key/401_test.go b/go/apps/api/routes/v2_keys_delete_key/401_test.go index 3ec9df0530..f95eac9430 100644 --- a/go/apps/api/routes/v2_keys_delete_key/401_test.go +++ b/go/apps/api/routes/v2_keys_delete_key/401_test.go @@ -20,11 +20,11 @@ func TestKeyDeleteUnauthorized(t *testing.T) { ctx := t.Context() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_delete_key/403_test.go b/go/apps/api/routes/v2_keys_delete_key/403_test.go index 286fa22ff3..4d64da0927 100644 --- a/go/apps/api/routes/v2_keys_delete_key/403_test.go +++ b/go/apps/api/routes/v2_keys_delete_key/403_test.go @@ -23,11 +23,11 @@ func TestKeyDeleteForbidden(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -117,10 +117,6 @@ func TestKeyDeleteForbidden(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: true, Int32: 100}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_delete_key/404_test.go b/go/apps/api/routes/v2_keys_delete_key/404_test.go index e12584764b..ea0c979ee7 100644 --- a/go/apps/api/routes/v2_keys_delete_key/404_test.go +++ b/go/apps/api/routes/v2_keys_delete_key/404_test.go @@ -21,11 +21,11 @@ func TestKeyDeleteNotFound(t *testing.T) { ctx := t.Context() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_delete_key/handler.go b/go/apps/api/routes/v2_keys_delete_key/handler.go index d9e83804aa..b297256cfb 100644 --- a/go/apps/api/routes/v2_keys_delete_key/handler.go +++ b/go/apps/api/routes/v2_keys_delete_key/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -26,11 +26,11 @@ type Response = openapi.V2KeysDeleteKeyResponseBody // Handler implements zen.Route interface for the v2 keys.deleteKey endpoint type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -47,7 +47,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.deleteKey") // Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -92,22 +92,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.DeleteKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: key.Api.ID, - Action: rbac.DeleteKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.DeleteKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.Api.ID, + Action: rbac.DeleteKey, + }), + ))) if err != nil { return err } @@ -137,7 +133,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { Event: auditlog.KeyDeleteEvent, WorkspaceID: auth.AuthorizedWorkspaceID, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("%s %s", description, key.ID), @@ -157,11 +153,12 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err }) - if err != nil { return err } + h.KeyCache.Remove(ctx, key.Hash) + return s.JSON(http.StatusOK, Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), diff --git a/go/apps/api/routes/v2_keys_get_key/200_test.go b/go/apps/api/routes/v2_keys_get_key/200_test.go index b750de0f86..625452b93a 100644 --- a/go/apps/api/routes/v2_keys_get_key/200_test.go +++ b/go/apps/api/routes/v2_keys_get_key/200_test.go @@ -2,7 +2,6 @@ package handler_test import ( "context" - "database/sql" "encoding/json" "fmt" "net/http" @@ -15,11 +14,10 @@ import ( "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_get_key" - "github.com/unkeyed/unkey/go/internal/services/keys" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestGetKeyByKeyID(t *testing.T) { @@ -27,12 +25,11 @@ func TestGetKeyByKeyID(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) @@ -40,83 +37,41 @@ func TestGetKeyByKeyID(t *testing.T) { // Create a workspace and user workspace := h.Resources().UserWorkspace - // Create a keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, + // Create a test API with encrypted keys using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, + Name: &apiName, + EncryptedKeys: true, }) - require.NoError(t, err) - - err = db.Query.UpdateKeyringKeyEncryption(ctx, h.DB.RW(), db.UpdateKeyringKeyEncryptionParams{ - ID: keyAuthID, - StoreEncryptedKeys: true, - }) - require.NoError(t, err) - - // Create a test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", - WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - // Create test identities - identityID := uid.New("identity") - identityExternalID := "test_user" - err = db.Query.InsertIdentity(ctx, h.DB.RW(), db.InsertIdentityParams{ - ID: identityID, - ExternalID: identityExternalID, + // Create test identity with ratelimit using testutil helper + identityID := h.CreateIdentity(seed.CreateIdentityRequest{ WorkspaceID: workspace.ID, - Environment: "", - CreatedAt: time.Now().UnixMilli(), + ExternalID: "test_user", Meta: []byte(`{"role": "admin"}`), + Ratelimits: []seed.CreateRatelimitRequest{ + { + WorkspaceID: workspace.ID, + Name: "api_calls", + Limit: 100, + Duration: 60000, + }, + }, }) - require.NoError(t, err) - - ratelimitID := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertIdentityRatelimit(ctx, h.DB.RW(), db.InsertIdentityRatelimitParams{ - ID: ratelimitID, - WorkspaceID: h.Resources().UserWorkspace.ID, - IdentityID: sql.NullString{String: identityID, Valid: true}, - Name: "api_calls", - Limit: 100, - Duration: 60000, // 1 minute - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + // Create test key with identity and encryption using testutil helper + keyName := "test-key" + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + IdentityID: &identityID, }) + keyID := key.KeyID + // key := keyResponse.Key - insertParams := db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: true, String: identityID}, - } - - err = db.Query.InsertKey(ctx, h.DB.RW(), insertParams) - require.NoError(t, err) - + // Add encryption for the key since API has encrypted keys enabled encryption, err := h.Vault.Encrypt(ctx, &vaultv1.EncryptRequest{ Keyring: workspace.ID, Data: key.Key, @@ -190,43 +145,24 @@ func TestGetKeyByKeyID(t *testing.T) { func TestGetKey_AdditionalScenarios(t *testing.T) { h := testutil.NewHarness(t) - ctx := context.Background() - route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) workspace := h.Resources().UserWorkspace - // Create keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", + // Create test API using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) // Create root key with appropriate permissions rootKey := h.CreateRootKey(workspace.ID, "api.*.read_key") @@ -236,12 +172,7 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { } t.Run("key with complex meta data", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - + // Create test key with complex meta using testutil helper complexMeta := map[string]interface{}{ "user_id": 12345, "plan": "premium", @@ -253,19 +184,15 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }, } metaBytes, _ := json.Marshal(complexMeta) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, + metaString := string(metaBytes) + keyName := "complex-meta-key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "complex-meta-key"}, - Meta: sql.NullString{Valid: true, String: string(metaBytes)}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Meta: &metaString, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: ptr.P(keyID), @@ -284,28 +211,15 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("key with expiration date", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - futureDate := time.Now().Add(24 * time.Hour).Truncate(time.Hour) - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "expiring-key"}, - Expires: sql.NullTime{Valid: true, Time: futureDate}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, + KeyAuthID: api.KeyAuthID.String, + Expires: &futureDate, }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } @@ -317,29 +231,15 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("key with credits and daily refill", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "credits-key"}, - RemainingRequests: sql.NullInt32{Valid: true, Int32: 50}, - RefillAmount: sql.NullInt32{Valid: true, Int32: 100}, - RefillDay: sql.NullInt16{Valid: false, Int16: 0}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(50)), + RefillAmount: ptr.P(int32(100)), }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } @@ -354,29 +254,16 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("key with monthly refill", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "monthly-refill-key"}, - RemainingRequests: sql.NullInt32{Valid: true, Int32: 1000}, - RefillAmount: sql.NullInt32{Valid: true, Int32: 2000}, - RefillDay: sql.NullInt16{Valid: true, Int16: 1}, // 1st of month - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(50)), + RefillAmount: ptr.P(int32(100)), + RefillDay: ptr.P(int16(1)), }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } @@ -390,80 +277,29 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("key with roles and permissions", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "rbac-key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - }) - require.NoError(t, err) - - // Create permissions - perm1ID := uid.New(uid.PermissionPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: perm1ID, - WorkspaceID: workspace.ID, - Name: "read_data", - Slug: "read_data", - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - perm2ID := uid.New(uid.PermissionPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: perm2ID, - WorkspaceID: workspace.ID, - Name: "write_data", - Slug: "write_data", - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Create role - roleID := uid.New(uid.RolePrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, - WorkspaceID: workspace.ID, - Name: "data_admin", - }) - require.NoError(t, err) - - // Assign permissions to key - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: keyID, - PermissionID: perm1ID, - WorkspaceID: workspace.ID, - }) - require.NoError(t, err) - - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: keyID, - PermissionID: perm2ID, - WorkspaceID: workspace.ID, - }) - require.NoError(t, err) - - // Assign role to key - err = db.Query.InsertKeyRole(ctx, h.DB.RW(), db.InsertKeyRoleParams{ - KeyID: keyID, - RoleID: roleID, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), + KeyAuthID: api.KeyAuthID.String, + Permissions: []seed.CreatePermissionRequest{{ + Name: "read_data", + Slug: "read_data", + Description: nil, + WorkspaceID: workspace.ID, + }, { + Name: "write_data", + Slug: "write_data", + Description: nil, + WorkspaceID: workspace.ID, + }}, + Roles: []seed.CreateRoleRequest{{ + Name: "data_admin", + Description: nil, + WorkspaceID: workspace.ID, + }}, }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } @@ -484,52 +320,33 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("key with ratelimits", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "ratelimited-key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - }) - require.NoError(t, err) - - // Create ratelimits for the key - rl1ID := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertKeyRatelimit(ctx, h.DB.RW(), db.InsertKeyRatelimitParams{ - ID: rl1ID, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - KeyID: sql.NullString{Valid: true, String: keyID}, - Name: "api_calls", - Limit: 100, - Duration: 60000, // 1 minute - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - rl2ID := uid.New(uid.RatelimitPrefix) - err = db.Query.InsertKeyRatelimit(ctx, h.DB.RW(), db.InsertKeyRatelimitParams{ - ID: rl2ID, - WorkspaceID: workspace.ID, - KeyID: sql.NullString{Valid: true, String: keyID}, - Name: "data_transfer", - Limit: 1000, - Duration: 3600000, // 1 hour - AutoApply: true, - CreatedAt: time.Now().UnixMilli(), + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "api_calls", + WorkspaceID: workspace.ID, + AutoApply: false, + Duration: 60000, // 1minute + Limit: 100, + IdentityID: nil, + KeyID: nil, + }, + { + Name: "data_transfer", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: 3600000, // 1 hour + Limit: 1000, + IdentityID: nil, + KeyID: nil, + }, + }, }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } @@ -564,26 +381,14 @@ func TestGetKey_AdditionalScenarios(t *testing.T) { }) t.Run("disabled key", func(t *testing.T) { - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, - }) - - err := db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - Name: sql.NullString{Valid: true, String: "disabled-key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: false, // Key is disabled + KeyAuthID: api.KeyAuthID.String, + Disabled: true, }) - require.NoError(t, err) req := handler.Request{ - KeyId: ptr.P(keyID), + KeyId: ptr.P(keyResponse.KeyID), Decrypt: ptr.P(false), } diff --git a/go/apps/api/routes/v2_keys_get_key/400_test.go b/go/apps/api/routes/v2_keys_get_key/400_test.go index fd71d13872..c3d5ef695d 100644 --- a/go/apps/api/routes/v2_keys_get_key/400_test.go +++ b/go/apps/api/routes/v2_keys_get_key/400_test.go @@ -16,12 +16,11 @@ func Test_GetKey_BadRequest(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_get_key/401_test.go b/go/apps/api/routes/v2_keys_get_key/401_test.go index fcd649ecd2..0bf26776ab 100644 --- a/go/apps/api/routes/v2_keys_get_key/401_test.go +++ b/go/apps/api/routes/v2_keys_get_key/401_test.go @@ -16,12 +16,11 @@ func Test_GetKey_Unauthorized(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_get_key/403_test.go b/go/apps/api/routes/v2_keys_get_key/403_test.go index 8c9f21545e..64600b86d5 100644 --- a/go/apps/api/routes/v2_keys_get_key/403_test.go +++ b/go/apps/api/routes/v2_keys_get_key/403_test.go @@ -24,11 +24,11 @@ func Test_GetKey_Forbidden(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) @@ -118,10 +118,6 @@ func Test_GetKey_Forbidden(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_get_key/404_test.go b/go/apps/api/routes/v2_keys_get_key/404_test.go index 04e420658b..2604031b65 100644 --- a/go/apps/api/routes/v2_keys_get_key/404_test.go +++ b/go/apps/api/routes/v2_keys_get_key/404_test.go @@ -17,12 +17,11 @@ func Test_GetKey_NotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, - Vault: h.Vault, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + Vault: h.Vault, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_get_key/412_test.go b/go/apps/api/routes/v2_keys_get_key/412_test.go index da5c63572c..d73c574687 100644 --- a/go/apps/api/routes/v2_keys_get_key/412_test.go +++ b/go/apps/api/routes/v2_keys_get_key/412_test.go @@ -1,58 +1,36 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_get_key" - "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestPreconditionError(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Vault: h.Vault, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Vault: h.Vault, } h.Register(route) - // Create API manually - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", + // Create API using testutil helper + apiName := "test-api" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) // Create a root key with appropriate permissions rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_key", "api.*.decrypt_key") @@ -63,33 +41,16 @@ func TestPreconditionError(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: h.Resources().UserWorkspace.ID, }) - insertParams := db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: h.Resources().UserWorkspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - } - - err = db.Query.InsertKey(ctx, h.DB.RW(), insertParams) - require.NoError(t, err) - // Test case for API ID with special characters t.Run("Try getting a recoverable key without being opt-in", func(t *testing.T) { req := handler.Request{ Decrypt: ptr.P(true), - KeyId: ptr.P(keyID), + KeyId: ptr.P(key.KeyID), } res := testutil.CallRoute[handler.Request, openapi.PreconditionFailedErrorResponse]( diff --git a/go/apps/api/routes/v2_keys_get_key/handler.go b/go/apps/api/routes/v2_keys_get_key/handler.go index ebe6ad6947..81f885b83a 100644 --- a/go/apps/api/routes/v2_keys_get_key/handler.go +++ b/go/apps/api/routes/v2_keys_get_key/handler.go @@ -11,7 +11,6 @@ import ( vaultv1 "github.com/unkeyed/unkey/go/gen/proto/vault/v1" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -29,12 +28,11 @@ type Response = openapi.V2KeysGetKeyResponseBody // Handler implements zen.Route interface for the v2 keys.getKey endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService - Vault *vault.Service + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + Vault *vault.Service } // Method returns the HTTP method this route responds to @@ -51,7 +49,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.getKey") // Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -113,22 +111,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.ReadKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: key.Api.ID, - Action: rbac.ReadKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.ReadKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.Api.ID, + Action: rbac.ReadKey, + }), + ))) if err != nil { return err } @@ -150,22 +144,26 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { decrypt := ptr.SafeDeref(req.Decrypt, false) var plaintext *string if decrypt { - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.DecryptKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: key.Api.ID, - Action: rbac.DecryptKey, - }), - ), - ) + if h.Vault == nil { + return fault.New("vault missing", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Public("Vault hasn't been set up."), + ) + } + + // Permission check + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.DecryptKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.Api.ID, + Action: rbac.DecryptKey, + }), + ))) if err != nil { return err } diff --git a/go/apps/api/routes/v2_keys_remove_permissions/200_test.go b/go/apps/api/routes/v2_keys_remove_permissions/200_test.go index 3a0fd30bc9..c085c61e5d 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/200_test.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/200_test.go @@ -11,8 +11,8 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,11 +21,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,61 +41,33 @@ func TestSuccess(t *testing.T) { } t.Run("remove single permission by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - // Create a permission - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.single.id", - Slug: "documents.read.remove.single.id", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, - }) - require.NoError(t, err) - // Add permission to key first - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: keyID, - PermissionID: permissionID, - WorkspaceID: workspace.ID, - CreatedAt: time.Now().UnixMilli(), + // Create a test key with permission using testutil helper + keyName := "Test Key" + permissionDescription := "Read documents permission" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: "documents.read.remove.single.id", + Slug: "documents.read.remove.single.id", + Description: &permissionDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID + permissionID := keyResponse.PermissionIds[0] // Verify key has the permission initially currentPermissions, err := db.Query.ListDirectPermissionsByKeyID(ctx, h.DB.RO(), keyID) @@ -146,62 +118,33 @@ func TestSuccess(t *testing.T) { }) t.Run("remove single permission by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - // Create a permission - permissionID := uid.New(uid.TestPrefix) + // Create a test key with permission using testutil helper + keyName := "Test Key" permissionName := "documents.write.remove.single.name" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: permissionName, - Slug: permissionName, - Description: sql.NullString{Valid: true, String: "Write documents permission"}, - }) - require.NoError(t, err) - - // Add permission to key first - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ - KeyID: keyID, - PermissionID: permissionID, - WorkspaceID: workspace.ID, - CreatedAt: time.Now().UnixMilli(), + permissionDescription := "Write documents permission" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: permissionName, + Slug: permissionName, + Description: &permissionDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -232,67 +175,44 @@ func TestSuccess(t *testing.T) { }) t.Run("remove multiple permissions", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) - - // Create permissions - permission1ID := uid.New(uid.TestPrefix) - permission1Name := "documents.read.remove.multiple" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission1ID, - WorkspaceID: workspace.ID, - Name: permission1Name, - Slug: permission1Name, - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + keyID := keyResponse.KeyID + + // Create permissions using testutil helpers + permission1Description := "Read documents permission" + permission1ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.remove.multiple", + Slug: "documents.read.remove.multiple", + Description: &permission1Description, }) - require.NoError(t, err) - permission2ID := uid.New(uid.TestPrefix) + permission2Description := "Write documents permission" permission2Name := "documents.write.remove.multiple" - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission2ID, - WorkspaceID: workspace.ID, - Name: permission2Name, - Slug: permission2Name, - Description: sql.NullString{Valid: true, String: "Write documents permission"}, + permission2ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: permission2Name, + Slug: permission2Name, + Description: &permission2Description, }) - require.NoError(t, err) // Add both permissions to key first - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ + err := db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ KeyID: keyID, PermissionID: permission1ID, WorkspaceID: workspace.ID, @@ -338,52 +258,31 @@ func TestSuccess(t *testing.T) { }) t.Run("idempotent operation - removing permission that isn't assigned", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a permission but don't assign it to the key - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.idempotent", - Slug: "documents.read.remove.idempotent", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + permissionDescription := "Read documents permission" + permissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.remove.idempotent", + Slug: "documents.read.remove.idempotent", + Description: &permissionDescription, }) - require.NoError(t, err) req := handler.Request{ KeyId: keyID, @@ -424,45 +323,26 @@ func TestSuccess(t *testing.T) { }) t.Run("remove some permissions from key with multiple permissions", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permissions keepPermissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: keepPermissionID, WorkspaceID: workspace.ID, Name: "documents.read.remove.partial.keep", @@ -528,75 +408,50 @@ func TestSuccess(t *testing.T) { }) t.Run("remove all permissions from key", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) - - // Create multiple permissions and add them all to the key - permission1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission1ID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.all.1", - Slug: "documents.read.remove.all.1", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + keyID := keyResponse.KeyID + + // Create multiple permissions using testutil helpers + permission1Description := "Read documents permission" + permission1ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.remove.all.1", + Slug: "documents.read.remove.all.1", + Description: &permission1Description, }) - require.NoError(t, err) - permission2ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission2ID, - WorkspaceID: workspace.ID, - Name: "documents.write.remove.all.2", - Slug: "documents.write.remove.all.2", - Description: sql.NullString{Valid: true, String: "Write documents permission"}, + permission2Description := "Write documents permission" + permission2ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.write.remove.all.2", + Slug: "documents.write.remove.all.2", + Description: &permission2Description, }) - require.NoError(t, err) - permission3ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permission3ID, - WorkspaceID: workspace.ID, - Name: "documents.delete.remove.all.3", - Slug: "documents.delete.remove.all.3", - Description: sql.NullString{Valid: true, String: "Delete documents permission"}, + permission3Description := "Delete documents permission" + permission3ID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.delete.remove.all.3", + Slug: "documents.delete.remove.all.3", + Description: &permission3Description, }) - require.NoError(t, err) // Add all permissions to key - err = db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ + err := db.Query.InsertKeyPermission(ctx, h.DB.RW(), db.InsertKeyPermissionParams{ KeyID: keyID, PermissionID: permission1ID, WorkspaceID: workspace.ID, diff --git a/go/apps/api/routes/v2_keys_remove_permissions/400_test.go b/go/apps/api/routes/v2_keys_remove_permissions/400_test.go index 4a4d10f7e5..a8b88fcbcd 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/400_test.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/400_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,40 +40,22 @@ func TestValidationErrors(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create a valid key for testing - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create a valid API and key for testing using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - validKeyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: validKeyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + validKeyID := keyResponse.KeyID t.Run("missing keyId", func(t *testing.T) { req := map[string]interface{}{ @@ -178,7 +159,7 @@ func TestValidationErrors(t *testing.T) { t.Run("permission missing both id and slug", func(t *testing.T) { // Create a permission for valid structure permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permissionID, WorkspaceID: workspace.ID, Name: "documents.read.remove.validation", diff --git a/go/apps/api/routes/v2_keys_remove_permissions/401_test.go b/go/apps/api/routes/v2_keys_remove_permissions/401_test.go index bc3838870a..1568731bc0 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/401_test.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/401_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_permissions" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestAuthenticationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -34,50 +28,32 @@ func TestAuthenticationErrors(t *testing.T) { // Create a workspace workspace := h.Resources().UserWorkspace - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.auth", - Slug: "documents.read.remove.auth", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + + keyName := "Test Key" + permissionDescription := "Read documents permission" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: "documents.read.remove.auth", + Slug: "documents.read.remove.auth", + Description: &permissionDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID + permissionID := keyResponse.PermissionIds[0] req := handler.Request{ KeyId: keyID, diff --git a/go/apps/api/routes/v2_keys_remove_permissions/403_test.go b/go/apps/api/routes/v2_keys_remove_permissions/403_test.go index 539394d5c1..751bc431f1 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/403_test.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/403_test.go @@ -12,8 +12,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +22,11 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -34,50 +34,32 @@ func TestAuthorizationErrors(t *testing.T) { // Create a workspace workspace := h.Resources().UserWorkspace - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - - permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.auth403", - Slug: "documents.read.remove.auth403", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + + keyName := "Test Key" + permissionDescription := "Read documents permission" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Permissions: []seed.CreatePermissionRequest{ + { + WorkspaceID: workspace.ID, + Name: "documents.read.remove.auth403", + Slug: "documents.read.remove.auth403", + Description: &permissionDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID + permissionID := keyResponse.PermissionIds[0] req := handler.Request{ KeyId: keyID, @@ -142,41 +124,22 @@ func TestAuthorizationErrors(t *testing.T) { }) require.NoError(t, err) - // Create keyring in other workspace - otherKeyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: otherKeyAuthID, - WorkspaceID: otherWorkspaceID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in other workspace using testutil helper + otherApiName := "Other Workspace API" + otherApi := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: otherWorkspaceID, + Name: &otherApiName, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create key in other workspace - otherKeyID := uid.New(uid.KeyPrefix) - otherKeyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: otherKeyID, - KeyringID: otherKeyAuthID, - Hash: hash.Sha256(otherKeyString), - Start: otherKeyString[:4], - WorkspaceID: otherWorkspaceID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Other Workspace Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + otherKeyName := "Other Workspace Key" + otherKeyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: otherWorkspaceID, + KeyAuthID: otherApi.KeyAuthID.String, + Name: &otherKeyName, }) - require.NoError(t, err) + otherKeyID := otherKeyResponse.KeyID // Create root key for original workspace (authorized for workspace.ID, not otherWorkspaceID) authorizedRootKey := h.CreateRootKey(workspace.ID, "api.*.update_key") diff --git a/go/apps/api/routes/v2_keys_remove_permissions/404_test.go b/go/apps/api/routes/v2_keys_remove_permissions/404_test.go index b02896d722..75e77c0270 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/404_test.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/404_test.go @@ -14,6 +14,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +23,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -42,16 +43,14 @@ func TestNotFoundErrors(t *testing.T) { } t.Run("key not found", func(t *testing.T) { - // Create a permission that exists - permissionID := uid.New(uid.TestPrefix) - err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ - PermissionID: permissionID, - WorkspaceID: workspace.ID, - Name: "documents.read.remove.404keynotfound", - Slug: "documents.read.remove.404keynotfound", - Description: sql.NullString{Valid: true, String: "Read documents permission"}, + // Create a permission that exists using testutil helper + permissionDescription := "Read documents permission" + permissionID := h.CreatePermission(seed.CreatePermissionRequest{ + WorkspaceID: workspace.ID, + Name: "documents.read.remove.404keynotfound", + Slug: "documents.read.remove.404keynotfound", + Description: &permissionDescription, }) - require.NoError(t, err) // Use a non-existent key ID nonExistentKeyID := uid.New(uid.KeyPrefix) @@ -79,41 +78,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("permission not found by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use a non-existent permission ID nonExistentPermissionID := uid.New(uid.TestPrefix) @@ -141,41 +121,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("permission not found by slug", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use a non-existent permission name nonExistentPermissionSlug := "nonexistent.permission.remove.name" @@ -253,10 +214,6 @@ func TestNotFoundErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) @@ -322,10 +279,6 @@ func TestNotFoundErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_remove_permissions/handler.go b/go/apps/api/routes/v2_keys_remove_permissions/handler.go index 28e6f98a60..2487db8de3 100644 --- a/go/apps/api/routes/v2_keys_remove_permissions/handler.go +++ b/go/apps/api/routes/v2_keys_remove_permissions/handler.go @@ -9,8 +9,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -25,11 +25,11 @@ type Response = openapi.V2KeysRemovePermissionsResponse // Handler implements zen.Route interface for the v2 keys remove permissions endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -47,7 +47,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.removePermissions") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -59,17 +59,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -198,7 +194,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthDisconnectPermissionKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Removed permission %s from key %s", permission.Name, req.KeyId), @@ -236,6 +232,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if err != nil { return err } + + h.KeyCache.Remove(ctx, key.Hash) } // 9. Get final state of direct permissions and build response diff --git a/go/apps/api/routes/v2_keys_remove_roles/200_test.go b/go/apps/api/routes/v2_keys_remove_roles/200_test.go index bd861a917d..83a6e2dc5c 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/200_test.go +++ b/go/apps/api/routes/v2_keys_remove_roles/200_test.go @@ -11,8 +11,8 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,11 +21,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,45 +41,26 @@ func TestSuccess(t *testing.T) { } t.Run("remove single role by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create roles role1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: role1ID, WorkspaceID: workspace.ID, Name: "admin", @@ -165,46 +146,27 @@ func TestSuccess(t *testing.T) { }) t.Run("remove single role by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) roleName := "editor" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: roleName, @@ -255,45 +217,26 @@ func TestSuccess(t *testing.T) { }) t.Run("idempotent operation - removing non-assigned role", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role but don't assign it to the key roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "unassigned_role", diff --git a/go/apps/api/routes/v2_keys_remove_roles/400_test.go b/go/apps/api/routes/v2_keys_remove_roles/400_test.go index dc5cc1feb8..ab28560a4e 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/400_test.go +++ b/go/apps/api/routes/v2_keys_remove_roles/400_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -139,40 +138,22 @@ func TestValidationErrors(t *testing.T) { }) t.Run("missing roles field", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use map to completely omit the roles field reqMap := map[string]interface{}{ @@ -192,40 +173,22 @@ func TestValidationErrors(t *testing.T) { }) t.Run("empty roles array", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -248,40 +211,22 @@ func TestValidationErrors(t *testing.T) { }) t.Run("role missing both id and name", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -307,40 +252,22 @@ func TestValidationErrors(t *testing.T) { }) t.Run("role not found by ID", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentRoleID := uid.New(uid.TestPrefix) @@ -367,40 +294,22 @@ func TestValidationErrors(t *testing.T) { }) t.Run("role not found by name", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentRoleName := "non_existent_role_name" @@ -462,45 +371,27 @@ func TestValidationErrors(t *testing.T) { }) t.Run("role from different workspace by ID", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a second workspace and role in it workspace2 := h.CreateWorkspace() roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace2.ID, Name: "role_different_workspace_" + uid.New(""), @@ -532,46 +423,28 @@ func TestValidationErrors(t *testing.T) { }) t.Run("role from different workspace by name", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a second workspace and role in it workspace2 := h.CreateWorkspace() roleID := uid.New(uid.TestPrefix) roleName := "role_different_workspace_by_name_" + uid.New("") - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace2.ID, Name: roleName, diff --git a/go/apps/api/routes/v2_keys_remove_roles/401_test.go b/go/apps/api/routes/v2_keys_remove_roles/401_test.go index 3f75f83853..6f07292308 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/401_test.go +++ b/go/apps/api/routes/v2_keys_remove_roles/401_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -35,44 +34,26 @@ func TestAuthenticationErrors(t *testing.T) { workspace := h.Resources().UserWorkspace t.Run("missing authorization header", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_missing_auth_" + uid.New(""), @@ -108,44 +89,26 @@ func TestAuthenticationErrors(t *testing.T) { }) t.Run("invalid bearer token", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_invalid_token_" + uid.New(""), @@ -182,44 +145,26 @@ func TestAuthenticationErrors(t *testing.T) { }) t.Run("malformed authorization header", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_malformed_" + uid.New(""), @@ -256,44 +201,26 @@ func TestAuthenticationErrors(t *testing.T) { }) t.Run("empty bearer token", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_empty_token_" + uid.New(""), @@ -330,44 +257,26 @@ func TestAuthenticationErrors(t *testing.T) { }) t.Run("non-existent root key", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_nonexistent_" + uid.New(""), @@ -408,44 +317,26 @@ func TestAuthenticationErrors(t *testing.T) { workspace2 := h.CreateWorkspace() rootKeyFromDifferentWorkspace := h.CreateRootKey(workspace2.ID, "api.*.update_key") - // Create a test keyring and key in original workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in original workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role in original workspace roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_diff_workspace_" + uid.New(""), diff --git a/go/apps/api/routes/v2_keys_remove_roles/403_test.go b/go/apps/api/routes/v2_keys_remove_roles/403_test.go index ffcffe1c48..7261e9b957 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/403_test.go +++ b/go/apps/api/routes/v2_keys_remove_roles/403_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -38,44 +37,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key WITHOUT update permissions rootKey := h.CreateRootKey(workspace.ID) // No permissions - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_no_perms_" + uid.New(""), @@ -114,44 +95,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with insufficient permissions rootKey := h.CreateRootKey(workspace.ID, "api.read.update_key") // Read instead of wildcard - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_partial_perms_" + uid.New(""), @@ -190,44 +153,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with unrelated permissions rootKey := h.CreateRootKey(workspace.ID, "different.permission.scope") - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_unrelated_perms_" + uid.New(""), @@ -266,44 +211,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with create permission instead of update rootKey := h.CreateRootKey(workspace.ID, "api.*.create_key") - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_create_perms_" + uid.New(""), @@ -342,44 +269,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with delete permission instead of update rootKey := h.CreateRootKey(workspace.ID, "api.*.delete_key") - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_delete_perms_" + uid.New(""), @@ -418,44 +327,26 @@ func TestAuthorizationErrors(t *testing.T) { // Create root key with specific API permission (not wildcard) rootKey := h.CreateRootKey(workspace.ID, "api.specific.update_key") - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_specific_perms_" + uid.New(""), @@ -495,44 +386,26 @@ func TestAuthorizationErrors(t *testing.T) { rootKey := h.CreateRootKey(workspace.ID, "api.*.create_key", "api.*.delete_key", "api.*.read_key") // Missing "api.*.update_key" - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_mixed_perms_" + uid.New(""), diff --git a/go/apps/api/routes/v2_keys_remove_roles/404_test.go b/go/apps/api/routes/v2_keys_remove_roles/404_test.go index 5be9f58d76..151eb90677 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/404_test.go +++ b/go/apps/api/routes/v2_keys_remove_roles/404_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_remove_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -79,40 +78,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("role not found by ID", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Generate a non-existent role ID nonExistentRoleID := uid.New(uid.TestPrefix) @@ -142,40 +123,22 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("role not found by name", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use a non-existent role name nonExistentRoleName := "non_existent_role_name" @@ -208,44 +171,26 @@ func TestNotFoundErrors(t *testing.T) { // Create two workspaces workspace2 := h.CreateWorkspace() - // Create API and key in workspace2 - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace2.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in workspace2 using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace2.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace2.ID, // Different workspace - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace2.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create role in original workspace roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "test_role_key_diff_workspace_" + uid.New(""), @@ -281,44 +226,26 @@ func TestNotFoundErrors(t *testing.T) { // Create two workspaces workspace2 := h.CreateWorkspace() - // Create API and key in original workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in original workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create role in workspace2 (different workspace) roleInDifferentWorkspace := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleInDifferentWorkspace, WorkspaceID: workspace2.ID, Name: "admin_diff_workspace_id_" + uid.New(""), @@ -353,45 +280,27 @@ func TestNotFoundErrors(t *testing.T) { // Create two workspaces workspace2 := h.CreateWorkspace() - // Create API and key in original workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in original workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create role in workspace2 (different workspace) roleInDifferentWorkspace := uid.New(uid.TestPrefix) roleName := "admin_by_name_" + uid.New("") - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleInDifferentWorkspace, WorkspaceID: workspace2.ID, Name: roleName, @@ -423,44 +332,26 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("multiple roles with one not found", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create one valid role validRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: validRoleID, WorkspaceID: workspace.ID, Name: "valid_role_multiple_" + uid.New(""), @@ -498,44 +389,26 @@ func TestNotFoundErrors(t *testing.T) { }) t.Run("mixed valid and invalid role references", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create one valid role validRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: validRoleID, WorkspaceID: workspace.ID, Name: "valid_role_mixed_" + uid.New(""), diff --git a/go/apps/api/routes/v2_keys_remove_roles/handler.go b/go/apps/api/routes/v2_keys_remove_roles/handler.go index 0c7c84369f..82eb626578 100644 --- a/go/apps/api/routes/v2_keys_remove_roles/handler.go +++ b/go/apps/api/routes/v2_keys_remove_roles/handler.go @@ -9,8 +9,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -24,12 +24,11 @@ type Response = openapi.V2KeysRemoveRolesResponse // Handler implements zen.Route interface for the v2 keys remove roles endpoint type Handler struct { - // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -47,7 +46,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.removeRoles") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -58,23 +57,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } - // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) - if err != nil { - return err - } - - // 4. Validate key exists and belongs to workspace + // 3. Validate key exists and belongs to workspace key, err := db.Query.FindKeyByID(ctx, h.DB.RO(), req.KeyId) if err != nil { if db.IsNotFound(err) { @@ -89,6 +72,23 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } + // TODO: Get api id + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.KeyAuthID, + Action: rbac.UpdateKey, + }), + ))) + if err != nil { + return err + } + // Validate key belongs to authorized workspace if key.WorkspaceID != auth.AuthorizedWorkspaceID { return fault.New("key not found", @@ -198,7 +198,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthDisconnectRoleKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Removed role %s from key %s", role.Name, req.KeyId), @@ -236,6 +236,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if err != nil { return err } + + h.KeyCache.Remove(ctx, key.Hash) } // 9. Get final state of roles and build response diff --git a/go/apps/api/routes/v2_keys_set_permissions/200_test.go b/go/apps/api/routes/v2_keys_set_permissions/200_test.go index 5530f753af..5024879782 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/200_test.go +++ b/go/apps/api/routes/v2_keys_set_permissions/200_test.go @@ -11,8 +11,8 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,11 +21,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,45 +41,27 @@ func TestSuccess(t *testing.T) { } t.Run("set permissions using permission IDs", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key using testutil helper + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permissions permission1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permission1ID, WorkspaceID: workspace.ID, Name: "documents.read.initial", @@ -168,45 +150,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set permissions using permission names", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permissions permission1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permission1ID, WorkspaceID: workspace.ID, Name: "documents.read.byname", @@ -267,45 +230,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set empty permissions (remove all)", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create and assign permissions permission1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permission1ID, WorkspaceID: workspace.ID, Name: "documents.read.empty", @@ -373,45 +317,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set permissions with no changes (idempotent)", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permission permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permissionID, WorkspaceID: workspace.ID, Name: "documents.read.idempotent", @@ -461,41 +386,22 @@ func TestSuccess(t *testing.T) { }) t.Run("create permission on-the-fly using slug", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use a slug that doesn't exist yet newPermissionSlug := "documents.create.onthefly" diff --git a/go/apps/api/routes/v2_keys_set_permissions/400_test.go b/go/apps/api/routes/v2_keys_set_permissions/400_test.go index e65cd683cb..984be05c24 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/400_test.go +++ b/go/apps/api/routes/v2_keys_set_permissions/400_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_permissions" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestBadRequest(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -60,40 +54,22 @@ func TestBadRequest(t *testing.T) { }) t.Run("missing permissions field", func(t *testing.T) { - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := map[string]interface{}{ "keyId": keyID, @@ -151,40 +127,22 @@ func TestBadRequest(t *testing.T) { }) t.Run("permission with neither id nor name", func(t *testing.T) { - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, @@ -211,40 +169,22 @@ func TestBadRequest(t *testing.T) { }) t.Run("permission with empty string id and name", func(t *testing.T) { - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID emptyString := "" req := handler.Request{ @@ -275,40 +215,22 @@ func TestBadRequest(t *testing.T) { }) t.Run("permission not found - invalid format", func(t *testing.T) { - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID invalidID := "invalid_permission_id" req := handler.Request{ @@ -361,40 +283,22 @@ func TestBadRequest(t *testing.T) { }) t.Run("permissions as string instead of array", func(t *testing.T) { - // Create test data - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID req := map[string]interface{}{ "keyId": keyID, diff --git a/go/apps/api/routes/v2_keys_set_permissions/401_test.go b/go/apps/api/routes/v2_keys_set_permissions/401_test.go index c409c2e930..53a9ed80df 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/401_test.go +++ b/go/apps/api/routes/v2_keys_set_permissions/401_test.go @@ -22,11 +22,11 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -62,10 +62,6 @@ func TestAuthenticationErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_set_permissions/403_test.go b/go/apps/api/routes/v2_keys_set_permissions/403_test.go index 8afeeaa9a8..add27d990b 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/403_test.go +++ b/go/apps/api/routes/v2_keys_set_permissions/403_test.go @@ -22,11 +22,11 @@ func TestForbidden(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -62,10 +62,6 @@ func TestForbidden(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_set_permissions/404_test.go b/go/apps/api/routes/v2_keys_set_permissions/404_test.go index 6e54af478b..63a82c32a7 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/404_test.go +++ b/go/apps/api/routes/v2_keys_set_permissions/404_test.go @@ -6,14 +6,13 @@ import ( "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_permissions" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +21,11 @@ func TestNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -81,40 +80,22 @@ func TestNotFound(t *testing.T) { }) t.Run("non-existent permission ID", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Use non-existent permission ID nonExistentPermissionID := uid.New(uid.TestPrefix) @@ -144,40 +125,22 @@ func TestNotFound(t *testing.T) { }) t.Run("non-existent permission name", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID nonExistentPermissionName := "nonexistent.permission.name" @@ -209,44 +172,26 @@ func TestNotFound(t *testing.T) { // Create another workspace otherWorkspace := h.CreateWorkspace() - // Create keyring and key in the other workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: otherWorkspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in the other workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: otherWorkspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: otherWorkspace.ID, // Different workspace - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: otherWorkspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permission in the authorized workspace permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permissionID, WorkspaceID: workspace.ID, Name: "documents.read.isolation", @@ -283,44 +228,26 @@ func TestNotFound(t *testing.T) { // Create another workspace otherWorkspace := h.CreateWorkspace() - // Create a test keyring and key in the authorized workspace - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key in the authorized workspace using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create permission in the other workspace permissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: permissionID, WorkspaceID: otherWorkspace.ID, // Different workspace Name: "documents.read.otherworkspace", @@ -354,44 +281,26 @@ func TestNotFound(t *testing.T) { }) t.Run("multiple permissions with early failure", func(t *testing.T) { - // Create a test keyring and key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a valid permission for the second item validPermissionID := uid.New(uid.TestPrefix) - err = db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ + err := db.Query.InsertPermission(ctx, h.DB.RW(), db.InsertPermissionParams{ PermissionID: validPermissionID, WorkspaceID: workspace.ID, Name: "documents.read.valid", diff --git a/go/apps/api/routes/v2_keys_set_permissions/handler.go b/go/apps/api/routes/v2_keys_set_permissions/handler.go index d4ebe2ac43..d6ea6d3cda 100644 --- a/go/apps/api/routes/v2_keys_set_permissions/handler.go +++ b/go/apps/api/routes/v2_keys_set_permissions/handler.go @@ -11,8 +11,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -27,12 +27,11 @@ type Response = openapi.V2KeysSetPermissionsResponse // Handler implements zen.Route interface for the v2 keys set permissions endpoint type Handler struct { - // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -50,7 +49,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.setPermissions") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -62,17 +61,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -255,7 +250,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthDisconnectPermissionKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Removed permission %s from key %s", permissionName, req.KeyId), @@ -299,7 +294,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectPermissionKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Added permission %s to key %s", permission.Name, req.KeyId), @@ -338,6 +333,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } + h.KeyCache.Remove(ctx, key.Hash) + // 10. Get final state of permissions and build response finalPermissions, err := db.Query.ListDirectPermissionsByKeyID(ctx, h.DB.RO(), req.KeyId) if err != nil { diff --git a/go/apps/api/routes/v2_keys_set_roles/200_test.go b/go/apps/api/routes/v2_keys_set_roles/200_test.go index 236946add7..86e852d670 100644 --- a/go/apps/api/routes/v2_keys_set_roles/200_test.go +++ b/go/apps/api/routes/v2_keys_set_roles/200_test.go @@ -11,8 +11,8 @@ import ( "github.com/stretchr/testify/require" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_roles" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -21,11 +21,11 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,51 +41,26 @@ func TestSuccess(t *testing.T) { } t.Run("set single role by ID", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API with keyring using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + // Create a test key with role using testutil helper + roleID := h.CreateRole(seed.CreateRoleRequest{ + Name: "admin_set_single_id", + WorkspaceID: workspace.ID, }) - require.NoError(t, err) - // Create a role - roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - Name: "admin_set_single_id", - Description: sql.NullString{Valid: true, String: "Admin role"}, + KeyAuthID: api.KeyAuthID.String, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Verify key has no roles initially currentRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), keyID) @@ -139,46 +114,27 @@ func TestSuccess(t *testing.T) { }) t.Run("set single role by name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role roleID := uid.New(uid.TestPrefix) roleName := "editor_set_single_name" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: roleName, @@ -218,45 +174,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set multiple roles mixed references", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create multiple roles adminRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: adminRoleID, WorkspaceID: workspace.ID, Name: "admin_set_multi", @@ -332,45 +269,26 @@ func TestSuccess(t *testing.T) { }) t.Run("replace existing roles", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create roles oldRoleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: oldRoleID, WorkspaceID: workspace.ID, Name: "admin_replace_old", @@ -455,45 +373,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set roles to empty - remove all roles", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role and assign it to the key roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "admin_remove_all", @@ -557,45 +456,26 @@ func TestSuccess(t *testing.T) { }) t.Run("set same roles as current - no changes", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create a role and assign it to the key roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: roleID, WorkspaceID: workspace.ID, Name: "admin_no_change", @@ -655,45 +535,26 @@ func TestSuccess(t *testing.T) { }) t.Run("role reference with both ID and name", func(t *testing.T) { - // Create a test keyring - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create API and key using testutil helpers + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create roles role1ID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: role1ID, WorkspaceID: workspace.ID, Name: "admin_set_both_ref", diff --git a/go/apps/api/routes/v2_keys_set_roles/400_test.go b/go/apps/api/routes/v2_keys_set_roles/400_test.go index 4b2e013a33..09c33bfa3a 100644 --- a/go/apps/api/routes/v2_keys_set_roles/400_test.go +++ b/go/apps/api/routes/v2_keys_set_roles/400_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_roles" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestValidationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -41,40 +35,22 @@ func TestValidationErrors(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create a test key for valid requests - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create a test API and key for valid requests using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - validKeyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: validKeyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + keyName := "Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + validKeyID := keyResponse.KeyID // Test case for missing keyId t.Run("missing keyId", func(t *testing.T) { diff --git a/go/apps/api/routes/v2_keys_set_roles/401_test.go b/go/apps/api/routes/v2_keys_set_roles/401_test.go index 10887af5fd..ceeee1873e 100644 --- a/go/apps/api/routes/v2_keys_set_roles/401_test.go +++ b/go/apps/api/routes/v2_keys_set_roles/401_test.go @@ -14,11 +14,11 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_set_roles/403_test.go b/go/apps/api/routes/v2_keys_set_roles/403_test.go index 8a9854711d..88db520adf 100644 --- a/go/apps/api/routes/v2_keys_set_roles/403_test.go +++ b/go/apps/api/routes/v2_keys_set_roles/403_test.go @@ -1,32 +1,26 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_set_roles" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestAuthorizationErrors(t *testing.T) { - ctx := context.Background() h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -34,51 +28,31 @@ func TestAuthorizationErrors(t *testing.T) { // Create a workspace workspace := h.Resources().UserWorkspace - // Create test data manually - // Create a keyring and test key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, - }) - require.NoError(t, err) - // Create a test role - roleID := uid.New(uid.TestPrefix) - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ - RoleID: roleID, + keyName := "Test Key" + roleDescription := "Test role" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ WorkspaceID: workspace.ID, - Name: "test-role", - Description: sql.NullString{Valid: true, String: "Test role"}, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Roles: []seed.CreateRoleRequest{ + { + WorkspaceID: workspace.ID, + Name: "test-role", + Description: &roleDescription, + }, + }, }) - require.NoError(t, err) + keyID := keyResponse.KeyID + roleID := keyResponse.RolesIds[0] // Test case for insufficient permissions - missing update_key t.Run("missing update_key permission", func(t *testing.T) { diff --git a/go/apps/api/routes/v2_keys_set_roles/404_test.go b/go/apps/api/routes/v2_keys_set_roles/404_test.go index 22f2d5db30..d90d2ab9e5 100644 --- a/go/apps/api/routes/v2_keys_set_roles/404_test.go +++ b/go/apps/api/routes/v2_keys_set_roles/404_test.go @@ -14,6 +14,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -22,11 +23,11 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -43,46 +44,27 @@ func TestNotFoundErrors(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - // Create test data - // Create a keyring and valid key - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{Valid: true, String: "test"}, - DefaultBytes: sql.NullInt32{Valid: true, Int32: 16}, - CreatedAtM: time.Now().UnixMilli(), + // Create test data using testutil helper + defaultPrefix := "test" + defaultBytes := int32(16) + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: workspace.ID, + DefaultPrefix: &defaultPrefix, + DefaultBytes: &defaultBytes, }) - require.NoError(t, err) - validKeyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: validKeyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Valid Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + keyName := "Valid Test Key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - require.NoError(t, err) + validKeyID := keyResponse.KeyID // Create a valid role validRoleID := uid.New(uid.TestPrefix) validRoleName := "valid-test-role" - err = db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ + err := db.Query.InsertRole(ctx, h.DB.RW(), db.InsertRoleParams{ RoleID: validRoleID, WorkspaceID: workspace.ID, Name: validRoleName, @@ -217,10 +199,6 @@ func TestNotFoundErrors(t *testing.T) { Meta: sql.NullString{Valid: false}, Expires: sql.NullTime{Valid: false}, RemainingRequests: sql.NullInt32{Valid: false}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, }) require.NoError(t, err) diff --git a/go/apps/api/routes/v2_keys_set_roles/handler.go b/go/apps/api/routes/v2_keys_set_roles/handler.go index c453b7f708..37901a69e7 100644 --- a/go/apps/api/routes/v2_keys_set_roles/handler.go +++ b/go/apps/api/routes/v2_keys_set_roles/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -25,12 +25,11 @@ type Response = openapi.V2KeysSetRolesResponse // Handler implements zen.Route interface for the v2 keys set roles endpoint type Handler struct { - // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -48,7 +47,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.setRoles") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -60,17 +59,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -222,7 +217,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthDisconnectRoleKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Removed role %s from key %s", removedRole.Name, req.KeyId), @@ -266,7 +261,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.AuthConnectRoleKeyEvent, ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: fmt.Sprintf("Added role %s to key %s", role.Name, req.KeyId), @@ -305,6 +300,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } + h.KeyCache.Remove(ctx, key.Hash) + // 10. Get final state of roles and build response finalRoles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), req.KeyId) if err != nil { diff --git a/go/apps/api/routes/v2_keys_update_credits/200_test.go b/go/apps/api/routes/v2_keys_update_credits/200_test.go index 156acd58aa..c80a5603a3 100644 --- a/go/apps/api/routes/v2_keys_update_credits/200_test.go +++ b/go/apps/api/routes/v2_keys_update_credits/200_test.go @@ -2,21 +2,18 @@ package handler_test import ( "context" - "database/sql" "fmt" "math/rand/v2" "net/http" "testing" - "time" "github.com/oapi-codegen/nullable" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_update_credits" - "github.com/unkeyed/unkey/go/internal/services/keys" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestKeyUpdateCreditsSuccess(t *testing.T) { @@ -24,11 +21,11 @@ func TestKeyUpdateCreditsSuccess(t *testing.T) { ctx := context.Background() route := &handler.Handler{ - Logger: h.Logger, - DB: h.DB, - Keys: h.Keys, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -36,54 +33,22 @@ func TestKeyUpdateCreditsSuccess(t *testing.T) { // Create a workspace and user workspace := h.Resources().UserWorkspace - // Create a keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create a test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", + // Create a test API and key with random initial credits using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + Name: &apiName, }) + keyName := "test-key" initialCredits := int32(rand.IntN(50)) - - insertParams := db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false, String: ""}, - RemainingRequests: sql.NullInt32{Int32: initialCredits, Valid: initialCredits > 0}, - } - - err = db.Query.InsertKey(ctx, h.DB.RW(), insertParams) - require.NoError(t, err) + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Remaining: &initialCredits, + }) + keyID := keyResponse.KeyID // Create a root key with appropriate permissions rootKey := h.CreateRootKey(workspace.ID, "api.*.update_key") @@ -173,6 +138,12 @@ func TestKeyUpdateCreditsSuccess(t *testing.T) { require.True(t, currentKey.RemainingRequests.Valid) currentCredits := int64(currentKey.RemainingRequests.Int32) + // If we are decreasing credits into the negative, it will be automatically set to 0 + shouldBeRemaining := int64(0) + if currentCredits-decreaseBy > 0 { + shouldBeRemaining = currentCredits - decreaseBy + } + req := handler.Request{ KeyId: keyID, Operation: openapi.Decrement, @@ -185,12 +156,12 @@ func TestKeyUpdateCreditsSuccess(t *testing.T) { require.NoError(t, err) require.Equal(t, 200, res.Status) require.NotNil(t, res.Body) - require.Equal(t, remaining, currentCredits-decreaseBy) + require.Equal(t, remaining, shouldBeRemaining) key, err := db.Query.FindKeyByID(ctx, h.DB.RO(), keyID) require.NoError(t, err) require.NotNil(t, key) require.Equal(t, key.RemainingRequests.Valid, true) - require.EqualValues(t, key.RemainingRequests.Int32, currentCredits-decreaseBy) + require.EqualValues(t, key.RemainingRequests.Int32, shouldBeRemaining) }) } diff --git a/go/apps/api/routes/v2_keys_update_credits/400_test.go b/go/apps/api/routes/v2_keys_update_credits/400_test.go index 624cd2cc2b..ef128bc33b 100644 --- a/go/apps/api/routes/v2_keys_update_credits/400_test.go +++ b/go/apps/api/routes/v2_keys_update_credits/400_test.go @@ -1,32 +1,27 @@ package handler_test import ( - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/oapi-codegen/nullable" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_update_credits" - "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestKeyUpdateCreditsBadRequest(t *testing.T) { h := testutil.NewHarness(t) - ctx := t.Context() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -34,52 +29,20 @@ func TestKeyUpdateCreditsBadRequest(t *testing.T) { // Create a workspace and user workspace := h.Resources().UserWorkspace - // Create a keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create a test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", + // Create a test API and key using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + keyName := "test-key" + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, }) - - insertParams := db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false, String: ""}, - RemainingRequests: sql.NullInt32{Int32: 0, Valid: false}, - } - - err = db.Query.InsertKey(ctx, h.DB.RW(), insertParams) - require.NoError(t, err) + keyID := keyResponse.KeyID // Create root key with read permissions rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.update_key") diff --git a/go/apps/api/routes/v2_keys_update_credits/401_test.go b/go/apps/api/routes/v2_keys_update_credits/401_test.go index 9b36ee8273..8d410a29f2 100644 --- a/go/apps/api/routes/v2_keys_update_credits/401_test.go +++ b/go/apps/api/routes/v2_keys_update_credits/401_test.go @@ -1,31 +1,27 @@ package handler_test import ( - "database/sql" "net/http" "testing" - "time" "github.com/oapi-codegen/nullable" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_update_credits" - "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) func TestKeyUpdateCreditsUnauthorized(t *testing.T) { h := testutil.NewHarness(t) - ctx := t.Context() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) @@ -33,52 +29,22 @@ func TestKeyUpdateCreditsUnauthorized(t *testing.T) { // Create a workspace and user workspace := h.Resources().UserWorkspace - // Create a keyAuth (keyring) for the API - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: workspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false}, - DefaultBytes: sql.NullInt32{Valid: false}, - }) - require.NoError(t, err) - - // Create a test API - apiID := uid.New("api") - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "Test API", + // Create a test API and key with credits using testutil helper + apiName := "Test API" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: workspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) - keyID := uid.New(uid.KeyPrefix) - key, _ := h.Keys.CreateKey(ctx, keys.CreateKeyRequest{ - Prefix: "test", - ByteLength: 16, + keyName := "test-key" + remainingRequests := int32(100) + keyResponse := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Name: &keyName, + Remaining: &remainingRequests, }) - - insertParams := db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: key.Hash, - Start: key.Start, - WorkspaceID: workspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "test-key"}, - Expires: sql.NullTime{Valid: false}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false, String: ""}, - RemainingRequests: sql.NullInt32{Int32: 100, Valid: true}, - } - - err = db.Query.InsertKey(ctx, h.DB.RW(), insertParams) - require.NoError(t, err) + keyID := keyResponse.KeyID req := handler.Request{ KeyId: keyID, diff --git a/go/apps/api/routes/v2_keys_update_credits/403_test.go b/go/apps/api/routes/v2_keys_update_credits/403_test.go index 8cfdd6506b..3650e0f0a0 100644 --- a/go/apps/api/routes/v2_keys_update_credits/403_test.go +++ b/go/apps/api/routes/v2_keys_update_credits/403_test.go @@ -1,132 +1,52 @@ package handler_test import ( - "context" - "database/sql" "fmt" "net/http" "testing" - "time" "github.com/oapi-codegen/nullable" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/openapi" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_update_credits" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) func TestKeyUpdateCreditsForbidden(t *testing.T) { - h := testutil.NewHarness(t) - ctx := context.Background() route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) - // Create API for testing - keyAuthID := uid.New(uid.KeyAuthPrefix) - err := db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: keyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - apiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: apiID, - Name: "test-api", + // Create API for testing using testutil helper + apiName := "test-api" + api := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Create another API for cross-API testing - otherKeyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: otherKeyAuthID, - WorkspaceID: h.Resources().UserWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, + Name: &apiName, }) - require.NoError(t, err) - otherApiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: otherApiID, - Name: "other-api", + diffApi := h.CreateApi(seed.CreateApiRequest{ WorkspaceID: h.Resources().UserWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: otherKeyAuthID}, - CreatedAtM: time.Now().UnixMilli(), + Name: &apiName, }) - require.NoError(t, err) - - // Create another Workspace for cross-API testing - otherWorkspace := h.CreateWorkspace() - - otherWsKeyAuthID := uid.New(uid.KeyAuthPrefix) - err = db.Query.InsertKeyring(ctx, h.DB.RW(), db.InsertKeyringParams{ - ID: otherWsKeyAuthID, - WorkspaceID: otherWorkspace.ID, - CreatedAtM: time.Now().UnixMilli(), - DefaultPrefix: sql.NullString{Valid: false, String: ""}, - DefaultBytes: sql.NullInt32{Valid: false, Int32: 0}, - }) - require.NoError(t, err) - - otherWsApiID := uid.New(uid.APIPrefix) - err = db.Query.InsertApi(ctx, h.DB.RW(), db.InsertApiParams{ - ID: otherWsApiID, - Name: "test-api", - WorkspaceID: otherWorkspace.ID, - AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, - KeyAuthID: sql.NullString{Valid: true, String: otherWsKeyAuthID}, - CreatedAtM: time.Now().UnixMilli(), - }) - require.NoError(t, err) - - // Create a test key - keyID := uid.New(uid.KeyPrefix) - keyString := "test_" + uid.New("") - err = db.Query.InsertKey(ctx, h.DB.RW(), db.InsertKeyParams{ - ID: keyID, - KeyringID: keyAuthID, - Hash: hash.Sha256(keyString), - Start: keyString[:4], - WorkspaceID: h.Resources().UserWorkspace.ID, - ForWorkspaceID: sql.NullString{Valid: false}, - Name: sql.NullString{Valid: true, String: "Test Key"}, - CreatedAtM: time.Now().UnixMilli(), - Enabled: true, - IdentityID: sql.NullString{Valid: false}, - Meta: sql.NullString{Valid: false}, - Expires: sql.NullTime{Valid: false}, - RemainingRequests: sql.NullInt32{Valid: true, Int32: 100}, - RatelimitAsync: sql.NullBool{Valid: false}, - RatelimitLimit: sql.NullInt32{Valid: false}, - RatelimitDuration: sql.NullInt64{Valid: false}, - Environment: sql.NullString{Valid: false}, + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: api.WorkspaceID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(100)), }) - require.NoError(t, err) req := handler.Request{ - KeyId: keyID, + KeyId: key.KeyID, Operation: openapi.Increment, Value: nullable.NewNullableWithValue(int64(10)), } @@ -178,7 +98,7 @@ func TestKeyUpdateCreditsForbidden(t *testing.T) { t.Run("cross api access", func(t *testing.T) { // Create root key with read permission for a single api - rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, fmt.Sprintf("api.%s.update_key", otherApiID)) + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, fmt.Sprintf("api.%s.update_key", diffApi.ID)) headers := http.Header{ "Content-Type": {"application/json"}, diff --git a/go/apps/api/routes/v2_keys_update_credits/404_test.go b/go/apps/api/routes/v2_keys_update_credits/404_test.go index 38c469261e..4854f46cc7 100644 --- a/go/apps/api/routes/v2_keys_update_credits/404_test.go +++ b/go/apps/api/routes/v2_keys_update_credits/404_test.go @@ -16,11 +16,11 @@ func TestUpdateKeyCreditsNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + KeyCache: h.Caches.VerificationKeyByHash, } h.Register(route) diff --git a/go/apps/api/routes/v2_keys_update_credits/handler.go b/go/apps/api/routes/v2_keys_update_credits/handler.go index c2663ba31c..eb8d7ab629 100644 --- a/go/apps/api/routes/v2_keys_update_credits/handler.go +++ b/go/apps/api/routes/v2_keys_update_credits/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -26,11 +26,11 @@ type Response = openapi.V2KeysUpdateCreditsResponse // Handler implements zen.Route interface for the v2 keys.updateCredits endpoint type Handler struct { - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] } // Method returns the HTTP method this route responds to @@ -47,7 +47,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.updateCredits") // Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -101,22 +101,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.UpdateKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: key.Api.ID, - Action: rbac.UpdateKey, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.UpdateKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.Api.ID, + Action: rbac.UpdateKey, + }), + ))) if err != nil { return err } @@ -124,7 +120,6 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if (req.Operation == openapi.Decrement || req.Operation == openapi.Increment) && (!req.Value.IsSpecified() || req.Value.IsNull()) { return fault.New("wrong operation usage", fault.Code(codes.App.Validation.InvalidInput.URN()), - fault.Internal("wrong operation usage"), fault.Public("When specifying an increment or decrement operation, a value must be provided."), ) } @@ -132,7 +127,6 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { if (req.Operation == openapi.Decrement || req.Operation == openapi.Increment) && !key.RemainingRequests.Valid { return fault.New("wrong operation usage", fault.Code(codes.App.Validation.InvalidInput.URN()), - fault.Internal("wrong operation usage"), fault.Public("You cannot increment or decrement a key with unlimited credits."), ) } @@ -213,7 +207,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.KeyUpdateEvent, Display: fmt.Sprintf("Updated Key %s, set remaining to %s.", key.ID, remaining), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, ActorType: auditlog.RootKeyActor, @@ -273,6 +267,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } } + h.KeyCache.Remove(ctx, key.Hash) + return s.JSON(http.StatusOK, Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), diff --git a/go/apps/api/routes/v2_keys_verify_key/200_test.go b/go/apps/api/routes/v2_keys_verify_key/200_test.go new file mode 100644 index 0000000000..0fe6d1f413 --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/200_test.go @@ -0,0 +1,509 @@ +package handler_test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" + "github.com/unkeyed/unkey/go/pkg/uid" +) + +func TestSuccess(t *testing.T) { + // ctx := context.Background() + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + // Create a workspace + workspace := h.Resources().UserWorkspace + // Create a root key with appropriate permissions + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + // Set up request headers + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("verifies key as valid", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + }) + + t.Run("verifies expired key as valid and then invalid", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Expires: ptr.P(time.Now().Add(time.Second * 3)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + time.Sleep(time.Second * 3) + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.EXPIRED, res.Body.Data.Code, "Key should be expired but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("disabled key", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Disabled: true, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.DISABLED, res.Body.Data.Code, "Key should be disabled but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("key with credits", func(t *testing.T) { + t.Run("allowed default credit cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(5)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + require.EqualValues(t, *res.Body.Data.Credits, int32(4), "Key should have 4 credits but got %d", *res.Body.Data.Credits) + }) + + t.Run("exceeding with default credit cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(0)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.USAGEEXCEEDED, res.Body.Data.Code, "Key should show usage exceeded but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + require.EqualValues(t, *res.Body.Data.Credits, int32(0), "Key should have 0 credits but got %d", *res.Body.Data.Credits) + }) + + t.Run("allowed custom credit cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(5)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Credits: &openapi.KeysVerifyKeyCredits{ + Cost: 5, + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + require.EqualValues(t, *res.Body.Data.Credits, int32(0), "Key should have 0 credits remaining but got %d", *res.Body.Data.Credits) + }) + + t.Run("exceeding with custom credit cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(5)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Credits: &openapi.KeysVerifyKeyCredits{ + Cost: 15, + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.USAGEEXCEEDED, res.Body.Data.Code, "Key should be usage exceeded but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + require.EqualValues(t, *res.Body.Data.Credits, int32(0), "Key should have 0 credits remaining but got %d", *res.Body.Data.Credits) + }) + + t.Run("allow credits 0 even when remaining 0", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Remaining: ptr.P(int32(0)), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Credits: &openapi.KeysVerifyKeyCredits{ + Cost: 0, + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be code valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + require.EqualValues(t, *res.Body.Data.Credits, int32(0), "Key should have 0 credits remaining but got %d", *res.Body.Data.Credits) + }) + }) + + t.Run("with ip whitelist", func(t *testing.T) { + ipWhitelistApi := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID, IpWhitelist: "127.0.0.1"}) + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: ipWhitelistApi.ID, + Key: key.Key, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.FORBIDDEN, res.Body.Data.Code, "Key should be forbidden but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("key with permissions", func(t *testing.T) { + t.Run("with role permission valid", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Roles: []seed.CreateRoleRequest{{ + Name: "test-role", + Description: nil, + WorkspaceID: workspace.ID, + Permissions: []seed.CreatePermissionRequest{{ + Name: "domain.write", + Slug: "domain.write", + Description: nil, + WorkspaceID: workspace.ID, + }}, + }}, + }) + + perms := &openapi.V2KeysVerifyKeyRequestBody_Permissions{} + perms.FromV2KeysVerifyKeyRequestBodyPermissions0(openapi.V2KeysVerifyKeyRequestBodyPermissions0("domain.write")) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Permissions: perms, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + require.Len(t, ptr.SafeDeref(res.Body.Data.Permissions), 1, "Key should be have a single permission attached") + }) + + t.Run("with direct permission valid", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Permissions: []seed.CreatePermissionRequest{{ + Name: "domain.read", + Slug: "domain.read", + Description: nil, + WorkspaceID: workspace.ID, + }}, + }) + + perms := &openapi.V2KeysVerifyKeyRequestBody_Permissions{} + perms.FromV2KeysVerifyKeyRequestBodyPermissions0(openapi.V2KeysVerifyKeyRequestBodyPermissions0("domain.read")) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Permissions: perms, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + require.Len(t, ptr.SafeDeref(res.Body.Data.Permissions), 1, "Key should be have a single permission attached") + }) + + t.Run("missing permissions", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + perms := &openapi.V2KeysVerifyKeyRequestBody_Permissions{} + perms.FromV2KeysVerifyKeyRequestBodyPermissions0(openapi.V2KeysVerifyKeyRequestBodyPermissions0("domain.write")) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Permissions: perms, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.INSUFFICIENTPERMISSIONS, res.Body.Data.Code, "Key should be no perms but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + }) + }) + + t.Run("key with auto applied ratelimit", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "auto-apply", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), + Limit: 1, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "Key should be ratelimited but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("key with specified ratelimit", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "requests", + WorkspaceID: workspace.ID, + AutoApply: false, + Duration: time.Minute.Milliseconds(), + Limit: 1, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "requests"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "Key should be ratelimited but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("key with custom ratelimit", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "requests", Cost: ptr.P(15), Duration: ptr.P(int(time.Minute.Milliseconds())), Limit: ptr.P(20)}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "Key should be ratelimited but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("key with identity ratelimit", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-123", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "tokens", + WorkspaceID: workspace.ID, + AutoApply: false, + Duration: (time.Minute * 30).Milliseconds(), + Limit: 4, + // Will be set later + IdentityID: nil, + KeyID: nil, + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "tokens", Cost: ptr.P(4)}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "Key should be ratelimited but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("returns correct information", func(t *testing.T) { + meta := map[string]interface{}{"key": "value"} + + raw, err := json.Marshal(meta) + require.NoError(t, err) + + externalId := uid.New("ext") + identity := h.CreateIdentity(seed.CreateIdentityRequest{WorkspaceID: workspace.ID, ExternalID: externalId, Meta: raw, Ratelimits: nil}) + keyName := "valid-info" + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + Name: ptr.P(keyName), + Roles: []seed.CreateRoleRequest{{ + Name: "read-writer", + Description: nil, + WorkspaceID: workspace.ID, + Permissions: []seed.CreatePermissionRequest{{ + Name: "domain.delete", + Slug: "domain.delete", + Description: nil, + WorkspaceID: workspace.ID, + }, { + Name: "domain.edit", + Slug: "domain.edit", + Description: nil, + WorkspaceID: workspace.ID, + }}, + }}, + Permissions: []seed.CreatePermissionRequest{{ + Name: "domain.create", + Slug: "domain.create", + Description: nil, + WorkspaceID: workspace.ID, + }}, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + require.Len(t, ptr.SafeDeref(res.Body.Data.Roles), 1, "Key should have 1 role") + require.Len(t, ptr.SafeDeref(res.Body.Data.Permissions), 3, "Key should have 3 permissions") + require.EqualValues(t, openapi.Identity{ExternalId: externalId, Id: identity, Meta: &meta, Ratelimits: nil}, ptr.SafeDeref(res.Body.Data.Identity)) + require.Equal(t, keyName, ptr.SafeDeref(res.Body.Data.Name), "Key should have the same name") + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/400_test.go b/go/apps/api/routes/v2_keys_verify_key/400_test.go new file mode 100644 index 0000000000..322565b416 --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/400_test.go @@ -0,0 +1,110 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestBadRequest(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + validHeaders := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("missing required fields", func(t *testing.T) { + t.Run("missing apiId", func(t *testing.T) { + req := handler.Request{ + Key: "test_key", + // ApiId missing + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("missing key", func(t *testing.T) { + req := handler.Request{ + ApiId: api.ID, + // Key missing + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + }) + + t.Run("invalid validation", func(t *testing.T) { + t.Run("invalid cost value", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + { + Name: "test", + Cost: ptr.P(-1), // Invalid negative cost + Limit: ptr.P(10), + Duration: ptr.P(60000), + }, + }, + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("invalid credits cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Credits: &openapi.KeysVerifyKeyCredits{ + Cost: -1, // Invalid negative cost + }, + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/401_test.go b/go/apps/api/routes/v2_keys_verify_key/401_test.go new file mode 100644 index 0000000000..5540059cc1 --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/401_test.go @@ -0,0 +1,86 @@ +package handler_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestUnauthorized(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + t.Run("missing authorization header", func(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + // Authorization header missing + } + + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("invalid bearer token", func(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer invalid_token_here"}, + } + + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 401, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("malformed authorization header", func(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {"InvalidFormat"}, + } + + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("requestId is returned in error response", func(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer invalid_token"}, + } + + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 401, res.Status) + require.NotNil(t, res.Body) + require.NotEmpty(t, res.Body.Meta.RequestId, "RequestId should be returned even in error response") + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/403_test.go b/go/apps/api/routes/v2_keys_verify_key/403_test.go new file mode 100644 index 0000000000..7ef7b76a43 --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/403_test.go @@ -0,0 +1,75 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestForbidden(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + t.Run("wrong api id", func(t *testing.T) { + api2 := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + req := handler.Request{ + ApiId: api2.ID, // Wrong API ID + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.FORBIDDEN, res.Body.Data.Code, "Key should be forbidden but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("root key without sufficient permissions", func(t *testing.T) { + // Create root key with insufficient permissions + limitedRootKey := h.CreateRootKey(workspace.ID, "api.*.read") // Wrong permission + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", limitedRootKey)}, + } + + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 403, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/404_test.go b/go/apps/api/routes/v2_keys_verify_key/404_test.go new file mode 100644 index 0000000000..5ed4f0369d --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/404_test.go @@ -0,0 +1,67 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" + "github.com/unkeyed/unkey/go/pkg/uid" +) + +func TestNotFound(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("key not found", func(t *testing.T) { + req := handler.Request{ + ApiId: api.ID, + Key: uid.New("test"), + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.NOTFOUND, res.Body.Data.Code, "Key should be not found but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) + + t.Run("soft deleted key", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Deleted: true, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.NOTFOUND, res.Body.Data.Code, "Key should be not found but got %s", res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be invalid but got %t", res.Body.Data.Valid) + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/412_test.go b/go/apps/api/routes/v2_keys_verify_key/412_test.go new file mode 100644 index 0000000000..e56b0fdd9b --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/412_test.go @@ -0,0 +1,159 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestPreconditionFailed(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + validHeaders := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("with identity - missing ratelimit", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-missing-ratelimit", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "existing-ratelimit", + WorkspaceID: workspace.ID, + Duration: 60_000, + Limit: 100, + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "does-not-exist"}, + }, + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 412, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + + // Should contain useful error message about missing ratelimit for key and identity + expectedMsg := fmt.Sprintf("ratelimit \"does-not-exist\" was requested but does not exist for key \"%s\" nor identity", key.KeyID) + require.Contains(t, res.Body.Error.Detail, expectedMsg) + require.Contains(t, res.Body.Error.Detail, identity) + require.Contains(t, res.Body.Error.Detail, "test-missing-ratelimit") + }) + + t.Run("without identity - missing ratelimit", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "existing-ratelimit", + WorkspaceID: workspace.ID, + Duration: 60_000, + Limit: 100, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "does-not-exist"}, + }, + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 412, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + + // Should contain error message indicating no identity connected + expectedMsg := fmt.Sprintf("ratelimit \"does-not-exist\" was requested but does not exist for key \"%s\" and there is no identity connected", key.KeyID) + require.Contains(t, res.Body.Error.Detail, expectedMsg) + }) + + t.Run("missing required fields", func(t *testing.T) { + t.Run("missing apiId", func(t *testing.T) { + req := handler.Request{ + Key: "test_key", + // ApiId missing + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + + t.Run("missing key", func(t *testing.T) { + req := handler.Request{ + ApiId: api.ID, + // Key missing + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) + }) + + t.Run("invalid ratelimit configuration", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + { + Name: "missing_config", + Cost: ptr.P(4), + // Missing limit and duration for custom ratelimit + }, + }, + } + + res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, validHeaders, req) + require.Equal(t, 412, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/handler.go b/go/apps/api/routes/v2_keys_verify_key/handler.go new file mode 100644 index 0000000000..9d0bfd8280 --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/handler.go @@ -0,0 +1,272 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/unkeyed/unkey/go/apps/api/openapi" + + "github.com/unkeyed/unkey/go/internal/services/auditlogs" + "github.com/unkeyed/unkey/go/internal/services/keys" + + "github.com/unkeyed/unkey/go/pkg/clickhouse" + "github.com/unkeyed/unkey/go/pkg/codes" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/rbac" + "github.com/unkeyed/unkey/go/pkg/zen" +) + +type Request = openapi.V2KeysVerifyKeyRequestBody +type Response = openapi.V2KeysVerifyKeyResponseBody + +const DefaultCost = 1 + +// Handler implements zen.Route interface for the v2 keys.verify endpoint +type Handler struct { + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + ClickHouse clickhouse.ClickHouse +} + +// Method returns the HTTP method this route responds to +func (h *Handler) Method() string { + return "POST" +} + +// Path returns the URL path pattern this route matches +func (h *Handler) Path() string { + return "/v2/keys.verifyKey" +} + +func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { + h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/keys.verifyKey") + + // Authentication + auth, err := h.Keys.GetRootKey(ctx, s) + if err != nil { + return err + } + + // Request validation + req, err := zen.BindBody[Request](s) + if err != nil { + return err + } + + key, err := h.Keys.Get(ctx, s, req.Key) + if err != nil { + return err + } + + // Validate key belongs to authorized workspace + if key.Key.WorkspaceID != auth.AuthorizedWorkspaceID { + return s.JSON(http.StatusOK, Response{ + Meta: openapi.Meta{ + RequestId: s.RequestID(), + }, + // nolint:exhaustruct + Data: openapi.KeysVerifyKeyResponseData{ + Code: openapi.NOTFOUND, + Valid: false, + }, + }) + } + + // Check if API is deleted + if key.Key.ApiDeletedAtM.Valid { + return s.JSON(http.StatusOK, Response{ + Meta: openapi.Meta{ + RequestId: s.RequestID(), + }, + // nolint:exhaustruct + Data: openapi.KeysVerifyKeyResponseData{ + Code: openapi.NOTFOUND, + Valid: false, + }, + }) + } + + // FIXME: We are leaking a keys existance here... by telling the user that he doesn't have perms + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.VerifyKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: key.Key.ApiID, + Action: rbac.VerifyKey, + }), + ))) + if err != nil { + return err + } + + opts := []keys.VerifyOption{keys.WithIPWhitelist(), keys.WithApiID(req.ApiId), keys.WithTags(ptr.SafeDeref(req.Tags))} + + // If a custom cost was specified, use it, otherwise use a DefaultCost of 1 + if req.Credits != nil { + opts = append(opts, keys.WithCredits(req.Credits.Cost)) + } else if key.Key.RemainingRequests.Valid { + opts = append(opts, keys.WithCredits(DefaultCost)) + } + + if req.Ratelimits != nil { + opts = append(opts, keys.WithRateLimits(*req.Ratelimits)) + } else { + // check auto applied ratelimits + opts = append(opts, keys.WithRateLimits(nil)) + } + + if req.Permissions != nil { + query, queryErr := convertPermissionsToQuery(*req.Permissions) + if queryErr != nil { + return queryErr + } + + opts = append(opts, keys.WithPermissions(query)) + } + + err = key.Verify(ctx, opts...) + if err != nil { + return err + } + + res := Response{ + Meta: openapi.Meta{ + RequestId: s.RequestID(), + }, + // nolint:exhaustruct + Data: openapi.KeysVerifyKeyResponseData{ + Code: key.ToOpenAPIStatus(), + Valid: key.Status == keys.StatusValid, + Enabled: ptr.P(key.Key.Enabled), + Name: ptr.P(key.Key.Name.String), + Permissions: ptr.P(key.Permissions), + Roles: ptr.P(key.Roles), + KeyId: ptr.P(key.Key.ID), + Credits: nil, + Expires: nil, + Identity: nil, + Meta: nil, + Ratelimits: nil, + }, + } + + remaining := key.Key.RemainingRequests + if remaining.Valid { + res.Data.Credits = ptr.P(remaining.Int32) + } + + if key.Key.Expires.Valid { + res.Data.Expires = ptr.P(key.Key.Expires.Time.UnixMilli()) + } + + if key.Key.Meta.Valid { + err = json.Unmarshal([]byte(key.Key.Meta.String), &res.Data.Meta) + if err != nil { + return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), + fault.Internal("unable to unmarshal key meta"), + fault.Public("We encountered an error while trying to unmarshal the key meta data."), + ) + } + } + + if key.Key.IdentityID.Valid { + res.Data.Identity = &openapi.Identity{ + ExternalId: key.Key.ExternalID.String, + Id: key.Key.IdentityID.String, + Ratelimits: nil, + Meta: nil, + } + + for _, ratelimit := range key.GetRatelimitConfigs() { + if ratelimit.IdentityID == "" { + continue + } + + res.Data.Identity.Ratelimits = append(res.Data.Identity.Ratelimits, openapi.RatelimitResponse{ + AutoApply: ratelimit.AutoApply == 1, + Duration: int64(ratelimit.Duration), + Id: ratelimit.ID, + Limit: int64(ratelimit.Limit), + Name: ratelimit.Name, + }) + } + + if len(key.Key.IdentityMeta) > 0 { + err = json.Unmarshal(key.Key.IdentityMeta, &res.Data.Identity.Meta) + if err != nil { + return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), + fault.Internal("unable to unmarshal identity meta"), + fault.Public("We encountered an error while trying to unmarshal the identity meta data."), + ) + } + } + } + + if len(key.RatelimitResults) > 0 { + ratelimitResponse := make([]openapi.VerifyKeyRatelimitData, 0) + for _, result := range key.RatelimitResults { + if result.Response == nil { + continue + } + + ratelimitResponse = append(ratelimitResponse, openapi.VerifyKeyRatelimitData{ + AutoApply: result.AutoApply, + Duration: result.Duration.Milliseconds(), + Exceeded: !result.Response.Success, + Id: result.Name, + Limit: result.Limit, + Name: result.Name, + Remaining: result.Response.Remaining, + Reset: result.Response.Reset.UnixMilli(), + }) + } + + if len(ratelimitResponse) > 0 { + res.Data.Ratelimits = ptr.P(ratelimitResponse) + } + } + + return s.JSON(http.StatusOK, res) +} + +// convertPermissionsToQuery converts OpenAPI permissions to rbac.PermissionQuery +func convertPermissionsToQuery(permissions openapi.V2KeysVerifyKeyRequestBody_Permissions) (rbac.PermissionQuery, error) { + // Try to unmarshal as string first (single permission) + if perm, err := permissions.AsV2KeysVerifyKeyRequestBodyPermissions0(); err == nil { + return rbac.S(perm), nil + } + + // Try to unmarshal as object (multiple permissions with operator) + if obj, err := permissions.AsV2KeysVerifyKeyRequestBodyPermissions1(); err == nil { + if len(obj.Permissions) == 0 { + return rbac.PermissionQuery{}, fmt.Errorf("permissions array cannot be empty") + } + + queries := make([]rbac.PermissionQuery, 0, len(obj.Permissions)) + for _, perm := range obj.Permissions { + queries = append(queries, rbac.S(perm)) + } + + switch obj.Type { + case openapi.And: + return rbac.And(queries...), nil + case openapi.Or: + return rbac.Or(queries...), nil + default: + return rbac.PermissionQuery{}, fmt.Errorf("unsupported operator: %s", obj.Type) + } + } + + return rbac.PermissionQuery{}, fmt.Errorf("invalid permissions format") +} diff --git a/go/apps/api/routes/v2_keys_verify_key/multilimit_test.go b/go/apps/api/routes/v2_keys_verify_key/multilimit_test.go new file mode 100644 index 0000000000..d10622464b --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/multilimit_test.go @@ -0,0 +1,469 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestMultiLimit(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("without identities", func(t *testing.T) { + t.Run("returns valid with multiple limits", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "10/10s", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 10, + }, + { + Name: "1/1min", + WorkspaceID: workspace.ID, + Duration: 60_000, + Limit: 1, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "10/10s", Cost: ptr.P(4)}, + {Name: "1/1min", Cost: ptr.P(1)}, + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + }) + + t.Run("returns RATE_LIMITED when one limit exceeded", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "10/10s-test", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 10, + }, + { + Name: "1/1min-test", + WorkspaceID: workspace.ID, + Duration: 60_000, + Limit: 1, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "10/10s-test", Cost: ptr.P(4)}, + {Name: "1/1min-test", Cost: ptr.P(2)}, // Exceeds limit of 1 + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + }) + + t.Run("with identity - key precedence over identity", func(t *testing.T) { + t.Run("key limits take precedence and pass", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-precedence-pass", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 1, // Identity has restrictive limit + }, + { + Name: "limit2", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 10, + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 4, // Key has more permissive limit + }, + }, + }) + + // Should pass 3 times due to key limit being 4 + for i := 0; i < 3; i++ { + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "limit1"}, {Name: "limit2"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + } + }) + + t.Run("key limits take precedence and reject", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-precedence-reject", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1-reject", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 10, // Identity has permissive limit + }, + { + Name: "limit2-reject", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 10, + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1-reject", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 1, // Key has restrictive limit + }, + }, + }) + + // First request should pass + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "limit1-reject"}, {Name: "limit2-reject"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + // Second request should be rate limited due to key limit + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + + t.Run("fallback identity limits still reject", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-fallback", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1-fallback", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 10, + }, + { + Name: "limit2-fallback", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 2, // This will be the fallback limit + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "limit1-fallback", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 4, // Key limit for limit1 + }, + }, + }) + + // Should pass twice (limit2 has limit of 2) + for i := 0; i < 2; i++ { + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "limit1-fallback"}, {Name: "limit2-fallback"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + } + + // Third request should be rate limited by limit2 (identity fallback) + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "limit1-fallback"}, {Name: "limit2-fallback"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + }) + + t.Run("with identity - shared rate limits across keys", func(t *testing.T) { + t.Run("rate limit is shared across multiple keys", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-shared", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "100per10m", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 5, // Small limit for testing + }, + }, + }) + + // Create multiple keys with same identity + key1 := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + }) + + key2 := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + }) + + // Use up some quota with key1 + for i := 0; i < 3; i++ { + req := handler.Request{ + ApiId: api.ID, + Key: key1.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "100per10m"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + } + + // key2 should only have 2 requests left due to shared limit + for i := 0; i < 2; i++ { + req := handler.Request{ + ApiId: api.ID, + Key: key2.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "100per10m"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + } + + // Next request with key2 should be rate limited + req := handler.Request{ + ApiId: api.ID, + Key: key2.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{Name: "100per10m"}}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + }) + + t.Run("without specifying ratelimits in request", func(t *testing.T) { + t.Run("should use auto-applied default limit", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "default-auto", + WorkspaceID: workspace.ID, + Duration: 20_000, + Limit: 1, + AutoApply: true, + }, + }, + }) + + // First request should pass + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + // No ratelimits specified - should use auto-applied + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + // Second request should be rate limited + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + }) + + t.Run("falls back to identity limits", func(t *testing.T) { + t.Run("should reject after identity limit hit", func(t *testing.T) { + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "test-identity-fallback", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "tokens-identity", + WorkspaceID: workspace.ID, + Duration: 10_000, + Limit: 10, + }, + { + Name: "10_per_10m-identity", + WorkspaceID: workspace.ID, + Duration: 600_000, + Limit: 10, + }, + }, + }) + + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + IdentityID: ptr.P(identity), + }) + + // First request with cost 4 should pass + req1 := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "tokens-identity", Cost: ptr.P(4)}, + {Name: "10_per_10m-identity"}, + }, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req1) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + // Second request with cost 6 should pass (total 10) + req2 := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "tokens-identity", Cost: ptr.P(6)}, + {Name: "10_per_10m-identity"}, + }, + } + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req2) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + // Third request with cost 1 should be rate limited (would be 11 total) + req3 := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{ + {Name: "tokens-identity", Cost: ptr.P(1)}, + {Name: "10_per_10m-identity"}, + }, + } + + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req3) + require.Equal(t, 200, res.Status) + require.NotNil(t, res.Body) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + }) + }) +} diff --git a/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go new file mode 100644 index 0000000000..5ec05a349a --- /dev/null +++ b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go @@ -0,0 +1,153 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_verify_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" +) + +func TestRatelimitResponse(t *testing.T) { + h := testutil.NewHarness(t) + + route := &handler.Handler{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + ClickHouse: h.ClickHouse, + } + + h.Register(route) + + workspace := h.Resources().UserWorkspace + rootKey := h.CreateRootKey(workspace.ID, "api.*.verify_key") + api := h.CreateApi(seed.CreateApiRequest{WorkspaceID: workspace.ID}) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("rate limit response fields validation", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "test-limit", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), + Limit: 5, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + require.True(t, res.Body.Data.Valid, "Key should be valid but got %t", res.Body.Data.Valid) + + // Validate rate limit response fields + require.NotNil(t, res.Body.Data.Ratelimits, "Rate limits should be present") + ratelimits := *res.Body.Data.Ratelimits + require.Len(t, ratelimits, 1, "Should have one rate limit") + + rl := ratelimits[0] + require.Equal(t, "test-limit", rl.Name, "Rate limit name should match") + require.Equal(t, int64(5), rl.Limit, "Rate limit limit should match") + require.Equal(t, time.Minute.Milliseconds(), rl.Duration, "Rate limit duration should match") + require.True(t, rl.AutoApply, "Rate limit should be auto-applied") + require.False(t, rl.Exceeded, "Rate limit should not be exceeded") + require.Equal(t, int64(4), rl.Remaining, "Should have 4 remaining requests") + require.Greater(t, rl.Reset, time.Now().UnixMilli(), "Reset time should be in the future") + }) + + t.Run("rate limit exceeded fields", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "strict-limit", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), + Limit: 1, + }, + }, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + } + + // First request should pass + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + // Second request should be rate limited + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid) + + // Validate rate limit response fields for exceeded limit + require.NotNil(t, res.Body.Data.Ratelimits, "Rate limits should be present") + ratelimits := *res.Body.Data.Ratelimits + require.Len(t, ratelimits, 1, "Should have one rate limit") + + rl := ratelimits[0] + require.True(t, rl.Exceeded, "Rate limit should be exceeded") + require.Equal(t, int64(0), rl.Remaining, "Should have 0 remaining requests") + }) + + t.Run("custom rate limit with cost", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + }) + + req := handler.Request{ + ApiId: api.ID, + Key: key.Key, + Ratelimits: &[]openapi.KeysVerifyKeyRatelimit{{ + Name: "custom", + Cost: ptr.P(3), + Duration: ptr.P(int(time.Minute.Milliseconds())), + Limit: ptr.P(10), + }}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %#v", res) + require.NotNil(t, res.Body) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Key should be valid but got %s", res.Body.Data.Code) + + // Validate custom rate limit response + require.NotNil(t, res.Body.Data.Ratelimits, "Rate limits should be present") + ratelimits := *res.Body.Data.Ratelimits + require.Len(t, ratelimits, 1, "Should have one rate limit") + + rl := ratelimits[0] + require.Equal(t, "custom", rl.Name, "Rate limit name should match") + require.Equal(t, int64(10), rl.Limit, "Rate limit limit should match") + require.Equal(t, int64(7), rl.Remaining, "Should have 7 remaining (10-3)") + require.False(t, rl.AutoApply, "Custom rate limit should not be auto-applied") + }) +} diff --git a/go/apps/api/routes/v2_permissions_create_permission/200_test.go b/go/apps/api/routes/v2_permissions_create_permission/200_test.go index 4637348a1f..934ff76480 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/200_test.go +++ b/go/apps/api/routes/v2_permissions_create_permission/200_test.go @@ -17,11 +17,10 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_permission/400_test.go b/go/apps/api/routes/v2_permissions_create_permission/400_test.go index 0a0dc5b4bb..092ad77767 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/400_test.go +++ b/go/apps/api/routes/v2_permissions_create_permission/400_test.go @@ -16,11 +16,10 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_permission/401_test.go b/go/apps/api/routes/v2_permissions_create_permission/401_test.go index aa5efdace7..d90ea3917b 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/401_test.go +++ b/go/apps/api/routes/v2_permissions_create_permission/401_test.go @@ -14,11 +14,10 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_permission/403_test.go b/go/apps/api/routes/v2_permissions_create_permission/403_test.go index 274574fb6d..34b3b07a63 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/403_test.go +++ b/go/apps/api/routes/v2_permissions_create_permission/403_test.go @@ -15,11 +15,10 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_permission/409_test.go b/go/apps/api/routes/v2_permissions_create_permission/409_test.go index 9767c6ba50..94937ed7c2 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/409_test.go +++ b/go/apps/api/routes/v2_permissions_create_permission/409_test.go @@ -15,11 +15,10 @@ func TestConflictErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_permission/handler.go b/go/apps/api/routes/v2_permissions_create_permission/handler.go index 90fb58b2ed..cc7689bc79 100644 --- a/go/apps/api/routes/v2_permissions_create_permission/handler.go +++ b/go/apps/api/routes/v2_permissions_create_permission/handler.go @@ -9,7 +9,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -27,11 +26,10 @@ type Response = openapi.V2PermissionsCreatePermissionResponseBody // Handler implements zen.Route interface for the v2 permissions create permission endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } // Method returns the HTTP method this route responds to @@ -46,7 +44,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -56,16 +54,11 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.CreatePermission, - }), - ) - + err = auth.Verify(ctx, keys.WithPermissions(rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.CreatePermission, + }))) if err != nil { return err } @@ -105,7 +98,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: "permission.create", ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: "Created " + permissionID, diff --git a/go/apps/api/routes/v2_permissions_create_role/200_test.go b/go/apps/api/routes/v2_permissions_create_role/200_test.go index 475a99ca5c..7746fb0353 100644 --- a/go/apps/api/routes/v2_permissions_create_role/200_test.go +++ b/go/apps/api/routes/v2_permissions_create_role/200_test.go @@ -17,11 +17,10 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_role/400_test.go b/go/apps/api/routes/v2_permissions_create_role/400_test.go index a0c78155e8..2e72920c16 100644 --- a/go/apps/api/routes/v2_permissions_create_role/400_test.go +++ b/go/apps/api/routes/v2_permissions_create_role/400_test.go @@ -15,11 +15,10 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_role/401_test.go b/go/apps/api/routes/v2_permissions_create_role/401_test.go index f97a2ed174..4899daeae7 100644 --- a/go/apps/api/routes/v2_permissions_create_role/401_test.go +++ b/go/apps/api/routes/v2_permissions_create_role/401_test.go @@ -14,11 +14,10 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_role/403_test.go b/go/apps/api/routes/v2_permissions_create_role/403_test.go index d27c2b64d7..b42989938a 100644 --- a/go/apps/api/routes/v2_permissions_create_role/403_test.go +++ b/go/apps/api/routes/v2_permissions_create_role/403_test.go @@ -17,11 +17,10 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_role/409_test.go b/go/apps/api/routes/v2_permissions_create_role/409_test.go index d2d44b1a83..8b58246918 100644 --- a/go/apps/api/routes/v2_permissions_create_role/409_test.go +++ b/go/apps/api/routes/v2_permissions_create_role/409_test.go @@ -19,11 +19,10 @@ func TestConflictErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_create_role/handler.go b/go/apps/api/routes/v2_permissions_create_role/handler.go index a8e3ccf2ef..c172494891 100644 --- a/go/apps/api/routes/v2_permissions_create_role/handler.go +++ b/go/apps/api/routes/v2_permissions_create_role/handler.go @@ -9,7 +9,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -27,11 +26,10 @@ type Response = openapi.V2PermissionsCreateRoleResponseBody // Handler implements zen.Route interface for the v2 permissions create role endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } // Method returns the HTTP method this route responds to @@ -49,7 +47,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.createRole") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -61,17 +59,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.CreateRole, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.CreateRole, + }), + ))) if err != nil { return err } @@ -114,7 +108,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: "role.create", ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: "Created " + roleID, diff --git a/go/apps/api/routes/v2_permissions_delete_permission/200_test.go b/go/apps/api/routes/v2_permissions_delete_permission/200_test.go index 9fd23c6aaf..b21aaf989f 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/200_test.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/200_test.go @@ -20,11 +20,10 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_permission/400_test.go b/go/apps/api/routes/v2_permissions_delete_permission/400_test.go index 5062c7dff9..28282e102b 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/400_test.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/400_test.go @@ -15,11 +15,10 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_permission/401_test.go b/go/apps/api/routes/v2_permissions_delete_permission/401_test.go index 7c2c653fd4..a919fc8122 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/401_test.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/401_test.go @@ -14,11 +14,10 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_permission/403_test.go b/go/apps/api/routes/v2_permissions_delete_permission/403_test.go index fe39433153..2253df8fe1 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/403_test.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/403_test.go @@ -21,11 +21,10 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_permission/404_test.go b/go/apps/api/routes/v2_permissions_delete_permission/404_test.go index 66e518246e..d2261ede27 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/404_test.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/404_test.go @@ -21,11 +21,10 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_permission/handler.go b/go/apps/api/routes/v2_permissions_delete_permission/handler.go index d94efbec80..53a430305c 100644 --- a/go/apps/api/routes/v2_permissions_delete_permission/handler.go +++ b/go/apps/api/routes/v2_permissions_delete_permission/handler.go @@ -7,7 +7,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -23,11 +22,10 @@ type Response = openapi.V2PermissionsDeletePermissionResponseBody // Handler implements zen.Route interface for the v2 permissions delete permission endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } // Method returns the HTTP method this route responds to @@ -43,7 +41,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -55,17 +53,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.DeletePermission, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.DeletePermission, + }), + ))) if err != nil { return err } @@ -128,7 +122,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: "permission.delete", ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: "Deleted " + req.PermissionId, diff --git a/go/apps/api/routes/v2_permissions_delete_role/200_test.go b/go/apps/api/routes/v2_permissions_delete_role/200_test.go index 384c19d794..b19b8e43dc 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/200_test.go +++ b/go/apps/api/routes/v2_permissions_delete_role/200_test.go @@ -21,11 +21,10 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_role/400_test.go b/go/apps/api/routes/v2_permissions_delete_role/400_test.go index f2df96aa65..4a577c8bd0 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/400_test.go +++ b/go/apps/api/routes/v2_permissions_delete_role/400_test.go @@ -15,11 +15,10 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_role/401_test.go b/go/apps/api/routes/v2_permissions_delete_role/401_test.go index b20e6d630a..97b83fd8ed 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/401_test.go +++ b/go/apps/api/routes/v2_permissions_delete_role/401_test.go @@ -14,11 +14,10 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_role/403_test.go b/go/apps/api/routes/v2_permissions_delete_role/403_test.go index 00db1d8d2a..732bb1fc68 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/403_test.go +++ b/go/apps/api/routes/v2_permissions_delete_role/403_test.go @@ -20,11 +20,10 @@ func TestPermissionErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_role/404_test.go b/go/apps/api/routes/v2_permissions_delete_role/404_test.go index 6941f4ae8f..7e1cb182d2 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/404_test.go +++ b/go/apps/api/routes/v2_permissions_delete_role/404_test.go @@ -20,11 +20,10 @@ func TestNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_delete_role/handler.go b/go/apps/api/routes/v2_permissions_delete_role/handler.go index 293993f61b..ac6a1872d5 100644 --- a/go/apps/api/routes/v2_permissions_delete_role/handler.go +++ b/go/apps/api/routes/v2_permissions_delete_role/handler.go @@ -7,7 +7,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" @@ -23,11 +22,10 @@ type Response = openapi.V2PermissionsDeleteRoleResponseBody // Handler implements zen.Route interface for the v2 permissions delete role endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService } // Method returns the HTTP method this route responds to @@ -45,7 +43,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.deleteRole") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -57,17 +55,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.DeleteRole, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.DeleteRole, + }), + ))) if err != nil { return err } @@ -130,7 +124,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: "role.delete", ActorType: auditlog.RootKeyActor, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorName: "root key", ActorMeta: map[string]any{}, Display: "Deleted " + req.RoleId, diff --git a/go/apps/api/routes/v2_permissions_get_permission/200_test.go b/go/apps/api/routes/v2_permissions_get_permission/200_test.go index 94d9c622d3..e72fd191ce 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/200_test.go +++ b/go/apps/api/routes/v2_permissions_get_permission/200_test.go @@ -20,10 +20,9 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_permission/400_test.go b/go/apps/api/routes/v2_permissions_get_permission/400_test.go index 98e09b12ad..9f800872c2 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/400_test.go +++ b/go/apps/api/routes/v2_permissions_get_permission/400_test.go @@ -15,10 +15,9 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_permission/401_test.go b/go/apps/api/routes/v2_permissions_get_permission/401_test.go index 777f7643e4..ebdc113343 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/401_test.go +++ b/go/apps/api/routes/v2_permissions_get_permission/401_test.go @@ -14,10 +14,9 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_permission/403_test.go b/go/apps/api/routes/v2_permissions_get_permission/403_test.go index e12faa3160..5d5c0fd4a8 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/403_test.go +++ b/go/apps/api/routes/v2_permissions_get_permission/403_test.go @@ -21,10 +21,9 @@ func TestPermissionErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_permission/404_test.go b/go/apps/api/routes/v2_permissions_get_permission/404_test.go index 9a3833f999..9612ed6fb0 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/404_test.go +++ b/go/apps/api/routes/v2_permissions_get_permission/404_test.go @@ -16,10 +16,9 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_permission/handler.go b/go/apps/api/routes/v2_permissions_get_permission/handler.go index 0c330506e0..c40348f4dc 100644 --- a/go/apps/api/routes/v2_permissions_get_permission/handler.go +++ b/go/apps/api/routes/v2_permissions_get_permission/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -21,10 +20,9 @@ type Response = openapi.V2PermissionsGetPermissionResponseBody // Handler implements zen.Route interface for the v2 permissions get permission endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -42,7 +40,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.getPermission") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -54,17 +52,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.ReadPermission, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.ReadPermission, + }), + ))) if err != nil { return err } @@ -96,6 +90,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { permissionResponse := openapi.Permission{ Id: permission.ID, Name: permission.Name, + Slug: permission.Slug, Description: nil, CreatedAt: permission.CreatedAtM, } diff --git a/go/apps/api/routes/v2_permissions_get_role/200_test.go b/go/apps/api/routes/v2_permissions_get_role/200_test.go index 11f26fbef1..46d29154d2 100644 --- a/go/apps/api/routes/v2_permissions_get_role/200_test.go +++ b/go/apps/api/routes/v2_permissions_get_role/200_test.go @@ -20,10 +20,9 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_role/400_test.go b/go/apps/api/routes/v2_permissions_get_role/400_test.go index 27995fd258..7e6ccf314c 100644 --- a/go/apps/api/routes/v2_permissions_get_role/400_test.go +++ b/go/apps/api/routes/v2_permissions_get_role/400_test.go @@ -15,10 +15,9 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_role/401_test.go b/go/apps/api/routes/v2_permissions_get_role/401_test.go index f6ddbb6497..f9b81befae 100644 --- a/go/apps/api/routes/v2_permissions_get_role/401_test.go +++ b/go/apps/api/routes/v2_permissions_get_role/401_test.go @@ -14,10 +14,9 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_role/403_test.go b/go/apps/api/routes/v2_permissions_get_role/403_test.go index a58bee9304..fdba133c80 100644 --- a/go/apps/api/routes/v2_permissions_get_role/403_test.go +++ b/go/apps/api/routes/v2_permissions_get_role/403_test.go @@ -20,10 +20,9 @@ func TestPermissionErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_role/404_test.go b/go/apps/api/routes/v2_permissions_get_role/404_test.go index 5df138a0ad..b74a8be2bf 100644 --- a/go/apps/api/routes/v2_permissions_get_role/404_test.go +++ b/go/apps/api/routes/v2_permissions_get_role/404_test.go @@ -16,10 +16,9 @@ func TestNotFoundErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_get_role/handler.go b/go/apps/api/routes/v2_permissions_get_role/handler.go index 981a5d0e66..e3a4ee1094 100644 --- a/go/apps/api/routes/v2_permissions_get_role/handler.go +++ b/go/apps/api/routes/v2_permissions_get_role/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -21,10 +20,9 @@ type Response = openapi.V2PermissionsGetRoleResponseBody // Handler implements zen.Route interface for the v2 permissions get role endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -42,7 +40,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.getRole") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -54,17 +52,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.ReadRole, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.ReadRole, + }), + ))) if err != nil { return err } @@ -107,8 +101,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { permission := openapi.Permission{ Id: perm.ID, Name: perm.Name, - Description: nil, + Slug: perm.Slug, CreatedAt: perm.CreatedAtM, + Description: nil, } // Add description only if it's valid @@ -123,9 +118,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { roleResponse := openapi.RoleWithPermissions{ Id: role.ID, Name: role.Name, - Description: nil, CreatedAt: role.CreatedAtM, Permissions: permissions, + Description: nil, } // Add description only if it's valid diff --git a/go/apps/api/routes/v2_permissions_list_permissions/200_test.go b/go/apps/api/routes/v2_permissions_list_permissions/200_test.go index d46cc98333..0e43df9e6f 100644 --- a/go/apps/api/routes/v2_permissions_list_permissions/200_test.go +++ b/go/apps/api/routes/v2_permissions_list_permissions/200_test.go @@ -21,10 +21,9 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_permissions/400_test.go b/go/apps/api/routes/v2_permissions_list_permissions/400_test.go index 60078e003b..6cfe09c3f9 100644 --- a/go/apps/api/routes/v2_permissions_list_permissions/400_test.go +++ b/go/apps/api/routes/v2_permissions_list_permissions/400_test.go @@ -15,10 +15,9 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_permissions/401_test.go b/go/apps/api/routes/v2_permissions_list_permissions/401_test.go index 45a238f2e3..5e8bb9c3b1 100644 --- a/go/apps/api/routes/v2_permissions_list_permissions/401_test.go +++ b/go/apps/api/routes/v2_permissions_list_permissions/401_test.go @@ -14,10 +14,9 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_permissions/403_test.go b/go/apps/api/routes/v2_permissions_list_permissions/403_test.go index 6f13000958..519119edb4 100644 --- a/go/apps/api/routes/v2_permissions_list_permissions/403_test.go +++ b/go/apps/api/routes/v2_permissions_list_permissions/403_test.go @@ -21,10 +21,9 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_permissions/handler.go b/go/apps/api/routes/v2_permissions_list_permissions/handler.go index dac8d59460..cef8e67edc 100644 --- a/go/apps/api/routes/v2_permissions_list_permissions/handler.go +++ b/go/apps/api/routes/v2_permissions_list_permissions/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -22,10 +21,9 @@ type Response = openapi.V2PermissionsListPermissionsResponseBody // Handler implements zen.Route interface for the v2 permissions list permissions endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -43,7 +41,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.listPermissions") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -58,17 +56,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { cursor := ptr.SafeDeref(req.Cursor, "") // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.ReadPermission, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.ReadPermission, + }), + ))) if err != nil { return err } @@ -105,6 +99,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { permission := openapi.Permission{ Id: perm.ID, Name: perm.Name, + Slug: perm.Slug, Description: nil, CreatedAt: perm.CreatedAtM, } diff --git a/go/apps/api/routes/v2_permissions_list_roles/200_test.go b/go/apps/api/routes/v2_permissions_list_roles/200_test.go index 053963ddd2..70f4f35b66 100644 --- a/go/apps/api/routes/v2_permissions_list_roles/200_test.go +++ b/go/apps/api/routes/v2_permissions_list_roles/200_test.go @@ -20,10 +20,9 @@ func TestSuccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_roles/400_test.go b/go/apps/api/routes/v2_permissions_list_roles/400_test.go index 17671bf0b1..cf59f67ca1 100644 --- a/go/apps/api/routes/v2_permissions_list_roles/400_test.go +++ b/go/apps/api/routes/v2_permissions_list_roles/400_test.go @@ -15,10 +15,9 @@ func TestValidationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_roles/401_test.go b/go/apps/api/routes/v2_permissions_list_roles/401_test.go index 7361abcd0d..f64aa32986 100644 --- a/go/apps/api/routes/v2_permissions_list_roles/401_test.go +++ b/go/apps/api/routes/v2_permissions_list_roles/401_test.go @@ -14,10 +14,9 @@ func TestAuthenticationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_roles/403_test.go b/go/apps/api/routes/v2_permissions_list_roles/403_test.go index 401feec4a1..e0b01bc850 100644 --- a/go/apps/api/routes/v2_permissions_list_roles/403_test.go +++ b/go/apps/api/routes/v2_permissions_list_roles/403_test.go @@ -20,10 +20,9 @@ func TestAuthorizationErrors(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_permissions_list_roles/handler.go b/go/apps/api/routes/v2_permissions_list_roles/handler.go index ebb1a38fbc..28eff01820 100644 --- a/go/apps/api/routes/v2_permissions_list_roles/handler.go +++ b/go/apps/api/routes/v2_permissions_list_roles/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -22,10 +21,9 @@ type Response = openapi.V2PermissionsListRolesResponseBody // Handler implements zen.Route interface for the v2 permissions list roles endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -43,7 +41,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { h.Logger.Debug("handling request", "requestId", s.RequestID(), "path", "/v2/permissions.listRoles") // 1. Authentication - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -58,17 +56,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { cursor := ptr.SafeDeref(req.Cursor, "") // 3. Permission check - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Rbac, - ResourceID: "*", - Action: rbac.ReadRole, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Rbac, + ResourceID: "*", + Action: rbac.ReadRole, + }), + ))) if err != nil { return err } @@ -117,8 +111,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { permission := openapi.Permission{ Id: perm.ID, Name: perm.Name, - Description: nil, CreatedAt: perm.CreatedAtM, + Slug: perm.Slug, + Description: nil, } // Add description only if it's valid diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/200_test.go b/go/apps/api/routes/v2_ratelimit_delete_override/200_test.go index 0ffc1f2f46..da40cea0f5 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/200_test.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/200_test.go @@ -44,11 +44,11 @@ func TestDeleteOverrideSuccessfully(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/400_test.go b/go/apps/api/routes/v2_ratelimit_delete_override/400_test.go index 21d0b01634..163d590ef3 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/400_test.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/400_test.go @@ -18,10 +18,10 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/401_test.go b/go/apps/api/routes/v2_ratelimit_delete_override/401_test.go index 3a0ff32e0e..31b74d6c1a 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/401_test.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/401_test.go @@ -14,11 +14,11 @@ func TestUnauthorizedAccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/403_test.go b/go/apps/api/routes/v2_ratelimit_delete_override/403_test.go index b2927ddc10..d3b2a56256 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/403_test.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/403_test.go @@ -45,11 +45,11 @@ func TestWorkspacePermissions(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/404_test.go b/go/apps/api/routes/v2_ratelimit_delete_override/404_test.go index 0ba4764be5..396614c050 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/404_test.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/404_test.go @@ -31,11 +31,11 @@ func TestNotFound(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_delete_override/handler.go b/go/apps/api/routes/v2_ratelimit_delete_override/handler.go index c7566d4123..4f10a7efdd 100644 --- a/go/apps/api/routes/v2_ratelimit_delete_override/handler.go +++ b/go/apps/api/routes/v2_ratelimit_delete_override/handler.go @@ -10,8 +10,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -25,12 +25,11 @@ type Response = openapi.V2RatelimitDeleteOverrideResponseBody // Handler implements zen.Route interface for the v2 ratelimit delete override endpoint type Handler struct { - // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + RatelimitNamespaceByNameCache cache.Cache[string, db.FindRatelimitNamespace] } // Method returns the HTTP method this route responds to @@ -45,7 +44,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -80,22 +79,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: namespace.ID, - Action: rbac.DeleteOverride, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: "*", - Action: rbac.DeleteOverride, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: namespace.ID, + Action: rbac.DeleteOverride, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: "*", + Action: rbac.DeleteOverride, + }), + ))) if err != nil { return err } @@ -138,7 +133,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitDeleteOverrideEvent, Display: fmt.Sprintf("Deleted override %s.", override.ID), - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorType: auditlog.RootKeyActor, ActorName: "root key", ActorMeta: map[string]any{}, @@ -166,6 +161,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } + h.RatelimitNamespaceByNameCache.Remove(ctx, namespace.Name) + return nil }) if err != nil { diff --git a/go/apps/api/routes/v2_ratelimit_get_override/200_test.go b/go/apps/api/routes/v2_ratelimit_get_override/200_test.go index 74a9f6c00e..8e884e62dd 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/200_test.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/200_test.go @@ -47,10 +47,10 @@ func TestGetOverrideSuccessfully(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_get_override/400_test.go b/go/apps/api/routes/v2_ratelimit_get_override/400_test.go index 77e1480155..f813c4fbf4 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/400_test.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/400_test.go @@ -19,10 +19,10 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) @@ -99,11 +99,11 @@ func TestBadRequests(t *testing.T) { require.NotNil(t, res.Body) require.Equal(t, "https://unkey.com/docs/api-reference/errors-v2/unkey/application/invalid_input", res.Body.Error.Type) - require.Equal(t, "You must provide either a namespace ID or name.", res.Body.Error.Detail) + require.Equal(t, "POST request body for '/v2/ratelimit.getOverride' failed to validate schema", res.Body.Error.Detail) require.Equal(t, http.StatusBadRequest, res.Body.Error.Status) require.Equal(t, "Bad Request", res.Body.Error.Title) require.NotEmpty(t, res.Body.Meta.RequestId) - require.Equal(t, len(res.Body.Error.Errors), 0) + require.Equal(t, len(res.Body.Error.Errors), 3) }) t.Run("missing authorization header", func(t *testing.T) { diff --git a/go/apps/api/routes/v2_ratelimit_get_override/401_test.go b/go/apps/api/routes/v2_ratelimit_get_override/401_test.go index 2757de6b30..59d304c7b7 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/401_test.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/401_test.go @@ -14,10 +14,10 @@ func TestUnauthorizedAccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_get_override/403_test.go b/go/apps/api/routes/v2_ratelimit_get_override/403_test.go index ebe44560b6..aa60db1006 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/403_test.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/403_test.go @@ -45,10 +45,10 @@ func TestWorkspacePermissions(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_get_override/404_test.go b/go/apps/api/routes/v2_ratelimit_get_override/404_test.go index 458f2136b6..ac7f1a9fae 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/404_test.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/404_test.go @@ -31,10 +31,10 @@ func TestOverrideNotFound(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_get_override/handler.go b/go/apps/api/routes/v2_ratelimit_get_override/handler.go index f53f9dfc38..fc3004a651 100644 --- a/go/apps/api/routes/v2_ratelimit_get_override/handler.go +++ b/go/apps/api/routes/v2_ratelimit_get_override/handler.go @@ -2,15 +2,20 @@ package handler import ( "context" + "database/sql" + "encoding/json" "net/http" + "strings" "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/match" "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/rbac" "github.com/unkeyed/unkey/go/pkg/zen" ) @@ -21,10 +26,10 @@ type Response = openapi.V2RatelimitGetOverrideResponseBody // Handler implements zen.Route interface for the v2 ratelimit get override endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + RatelimitNamespaceByNameCache cache.Cache[string, db.FindRatelimitNamespace] } // Method returns the HTTP method this route responds to @@ -39,7 +44,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -53,56 +58,80 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - namespace, err := getNamespace(ctx, h, auth.AuthorizedWorkspaceID, req) - + response, err := db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{ + WorkspaceID: auth.AuthorizedWorkspaceID, + Name: sql.NullString{String: ptr.SafeDeref(req.NamespaceName), Valid: req.NamespaceName != nil}, + ID: sql.NullString{String: ptr.SafeDeref(req.NamespaceId), Valid: req.NamespaceId != nil}, + }) if err != nil { - // already handled correctly in getNamespace - return err + if db.IsNotFound(err) { + return fault.New("namespace not found", + fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), + fault.Internal("namespace not found"), fault.Public("The namespace was not found."), + ) + } + + return fault.Wrap(err, + fault.Code(codes.App.Internal.ServiceUnavailable.URN()), + fault.Internal("database failed to find the namespace"), fault.Public("Error finding the ratelimit namespace."), + ) + } + + namespace := db.FindRatelimitNamespace{ + ID: response.ID, + WorkspaceID: response.WorkspaceID, + Name: response.Name, + CreatedAtM: response.CreatedAtM, + UpdatedAtM: response.UpdatedAtM, + DeletedAtM: response.DeletedAtM, + DirectOverrides: make(map[string]db.FindRatelimitNamespaceLimitOverride), + WildcardOverrides: make([]db.FindRatelimitNamespaceLimitOverride, 0), } - if namespace.WorkspaceID != auth.AuthorizedWorkspaceID { - return fault.New("namespace not found", - fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), - fault.Internal("namespace was deleted"), fault.Public("This namespace does not exist."), + overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0) + err = json.Unmarshal(response.Overrides.([]byte), &overrides) + if err != nil { + return fault.Wrap(err, + fault.Internal("unable to unmarshal ratelimit overrides"), + fault.Public("We're unable to parse the ratelimits overrides."), ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: namespace.ID, - Action: rbac.ReadOverride, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: "*", - Action: rbac.ReadOverride, - }), - ), - ) + for _, override := range overrides { + namespace.DirectOverrides[override.Identifier] = override + if strings.Contains(override.Identifier, "*") { + namespace.WildcardOverrides = append(namespace.WildcardOverrides, override) + } + } + + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: namespace.ID, + Action: rbac.ReadOverride, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: "*", + Action: rbac.ReadOverride, + }), + ))) if err != nil { return err } - override, err := db.Query.FindRatelimitOverrideByIdentifier(ctx, h.DB.RO(), db.FindRatelimitOverrideByIdentifierParams{ - WorkspaceID: auth.AuthorizedWorkspaceID, - NamespaceID: namespace.ID, - Identifier: req.Identifier, - }) + override, found, err := matchOverride(req.Identifier, namespace) + if err != nil { + return fault.Wrap(err, + fault.Code(codes.App.Internal.UnexpectedError.URN()), + fault.Internal("error matching overrides"), fault.Public("Error matching ratelimit override"), + ) + } - if db.IsNotFound(err) { + if !found { return fault.New("override not found", fault.Code(codes.Data.RatelimitOverride.NotFound.URN()), - fault.Internal("override not found"), fault.Public("This override does not exist."), - ) - } - if err != nil { - return fault.Wrap(err, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("database failed to find the override"), fault.Public("Error finding the ratelimit override."), + fault.Public("This override does not exist."), ) } @@ -111,52 +140,32 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { RequestId: s.RequestID(), }, Data: openapi.RatelimitOverride{ - OverrideId: override.ID, - Duration: int64(override.Duration), + NamespaceId: namespace.ID, + Limit: override.Limit, + Duration: override.Duration, Identifier: override.Identifier, - NamespaceId: override.NamespaceID, - Limit: int64(override.Limit), }, }) } -func getNamespace(ctx context.Context, h *Handler, workspaceID string, req Request) (namespace db.RatelimitNamespace, err error) { - - switch { - case req.NamespaceId != nil: - { - namespace, err = db.Query.FindRatelimitNamespaceByID(ctx, h.DB.RO(), *req.NamespaceId) - break - } - case req.NamespaceName != nil: - { - namespace, err = db.Query.FindRatelimitNamespaceByName(ctx, h.DB.RO(), db.FindRatelimitNamespaceByNameParams{ - WorkspaceID: workspaceID, - Name: *req.NamespaceName, - }) - break - } - default: - return db.RatelimitNamespace{}, fault.New("namespace id or name required", - fault.Code(codes.App.Validation.InvalidInput.URN()), - fault.Internal("namespace id or name required"), fault.Public("You must provide either a namespace ID or name."), - ) +func matchOverride(identifier string, namespace db.FindRatelimitNamespace) (db.FindRatelimitNamespaceLimitOverride, bool, error) { + if override, ok := namespace.DirectOverrides[identifier]; ok { + return override, true, nil } - if err != nil { + for _, override := range namespace.WildcardOverrides { + ok, err := match.Wildcard(identifier, override.Identifier) + if err != nil { + return db.FindRatelimitNamespaceLimitOverride{}, false, err + } - if db.IsNotFound(err) { - return db.RatelimitNamespace{}, fault.New("namespace not found", - fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), - fault.Internal("namespace not found"), fault.Public("The namespace was not found."), - ) + if !ok { + continue } - return db.RatelimitNamespace{}, fault.Wrap(err, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("database failed to find the namespace"), fault.Public("Error finding the ratelimit namespace."), - ) + return override, true, nil } - return namespace, nil + + return db.FindRatelimitNamespaceLimitOverride{}, false, nil } diff --git a/go/apps/api/routes/v2_ratelimit_limit/200_test.go b/go/apps/api/routes/v2_ratelimit_limit/200_test.go index 39cff406b3..6d0113dd61 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/200_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/200_test.go @@ -26,10 +26,8 @@ func TestLimitSuccessfully(t *testing.T) { Keys: h.Keys, Logger: h.Logger, ClickHouse: h.ClickHouse, - Permissions: h.Permissions, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_limit/400_test.go b/go/apps/api/routes/v2_ratelimit_limit/400_test.go index 4e429d2d67..a77099c626 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/400_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/400_test.go @@ -23,10 +23,8 @@ func TestBadRequests(t *testing.T) { DB: h.DB, Keys: h.Keys, Logger: h.Logger, - Permissions: h.Permissions, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) @@ -106,11 +104,10 @@ func TestMissingAuthorizationHeader(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Ratelimit: h.Ratelimit, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Ratelimit: h.Ratelimit, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_limit/401_test.go b/go/apps/api/routes/v2_ratelimit_limit/401_test.go index b2b29409a5..239898523f 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/401_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/401_test.go @@ -16,10 +16,8 @@ func TestUnauthorizedAccess(t *testing.T) { DB: h.DB, Keys: h.Keys, Logger: h.Logger, - Permissions: h.Permissions, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_limit/403_test.go b/go/apps/api/routes/v2_ratelimit_limit/403_test.go index 0dd8f6db69..176ad9fa93 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/403_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/403_test.go @@ -34,10 +34,8 @@ func TestWorkspacePermissions(t *testing.T) { DB: h.DB, Keys: h.Keys, Logger: h.Logger, - Permissions: h.Permissions, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_limit/404_test.go b/go/apps/api/routes/v2_ratelimit_limit/404_test.go index 2d016bfc36..5b91d0c61f 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/404_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/404_test.go @@ -23,10 +23,8 @@ func TestNamespaceNotFound(t *testing.T) { DB: h.DB, Keys: h.Keys, Logger: h.Logger, - Permissions: h.Permissions, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go b/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go index 26ccfc82db..44255c9fb9 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go @@ -60,11 +60,9 @@ func TestRateLimitAccuracy(t *testing.T) { DB: h.DB, Keys: h.Keys, Logger: h.Logger, - Permissions: h.Permissions, ClickHouse: h.ClickHouse, Ratelimit: h.Ratelimit, RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, - RatelimitOverrideMatchesCache: h.Caches.RatelimitOverridesMatch, } h.Register(route) ctx := context.Background() diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index 87c60c74bd..3693a199a4 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -3,14 +3,15 @@ package v2RatelimitLimit import ( "context" "database/sql" - "errors" + "encoding/json" "net/http" "strconv" + "strings" "time" "github.com/unkeyed/unkey/go/apps/api/openapi" + "github.com/unkeyed/unkey/go/internal/services/caches" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/clickhouse" @@ -18,6 +19,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/match" "github.com/unkeyed/unkey/go/pkg/otel/logging" "github.com/unkeyed/unkey/go/pkg/otel/tracing" "github.com/unkeyed/unkey/go/pkg/rbac" @@ -34,10 +36,8 @@ type Handler struct { Keys keys.KeyService DB db.Database ClickHouse clickhouse.Bufferer - Permissions permissions.PermissionService Ratelimit ratelimit.Service - RatelimitNamespaceByNameCache cache.Cache[db.FindRatelimitNamespaceByNameParams, db.RatelimitNamespace] - RatelimitOverrideMatchesCache cache.Cache[db.ListRatelimitOverrideMatchesParams, []db.RatelimitOverride] + RatelimitNamespaceByNameCache cache.Cache[string, db.FindRatelimitNamespace] TestMode bool } @@ -54,7 +54,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { // Authenticate the request with a root key - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -72,110 +72,100 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { cost = *req.Cost } - ctx, span := tracing.Start(ctx, "FindRatelimitNamespaceByName") + ctx, span := tracing.Start(ctx, "FindRatelimitNamespace") + namespace, err := h.RatelimitNamespaceByNameCache.SWR(ctx, req.Namespace, func(ctx context.Context) (db.FindRatelimitNamespace, error) { + response, err := db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{ + WorkspaceID: auth.AuthorizedWorkspaceID, + Name: sql.NullString{String: req.Namespace, Valid: true}, + ID: sql.NullString{String: "", Valid: false}, + }) + result := db.FindRatelimitNamespace{} // nolint:exhaustruct + if err != nil { + return result, err + } - findNamespaceArgs := db.FindRatelimitNamespaceByNameParams{ - WorkspaceID: auth.AuthorizedWorkspaceID, - Name: req.Namespace, - } - namespace, err := h.RatelimitNamespaceByNameCache.SWR(ctx, findNamespaceArgs, func(ctx context.Context) (db.RatelimitNamespace, error) { - return db.Query.FindRatelimitNamespaceByName(ctx, h.DB.RO(), findNamespaceArgs) - }, func(err error) cache.Op { - if err == nil { - // everything went well and we have a namespace response - return cache.WriteValue + result = db.FindRatelimitNamespace{ + ID: response.ID, + WorkspaceID: response.WorkspaceID, + Name: response.Name, + CreatedAtM: response.CreatedAtM, + UpdatedAtM: response.UpdatedAtM, + DeletedAtM: response.DeletedAtM, + DirectOverrides: make(map[string]db.FindRatelimitNamespaceLimitOverride), + WildcardOverrides: make([]db.FindRatelimitNamespaceLimitOverride, 0), } - if errors.Is(err, sql.ErrNoRows) { - // the response is empty, we need to store that the namespace does not exist - return cache.WriteNull + + overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0) + err = json.Unmarshal(response.Overrides.([]byte), &overrides) + if err != nil { + return result, err } - // this is a noop in the cache - return cache.Noop - }) + for _, override := range overrides { + result.DirectOverrides[override.Identifier] = override + if strings.Contains(override.Identifier, "*") { + result.WildcardOverrides = append(result.WildcardOverrides, override) + } + } + + return result, nil + }, caches.DefaultFindFirstOp) span.End() - if db.IsNotFound(err) { - return fault.New("namespace was deleted", - fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), - fault.Internal("namespace not found"), fault.Public("This namespace does not exist."), - ) - } + if err != nil { + if db.IsNotFound(err) { + return fault.New("namespace was deleted", + fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), + fault.Public("This namespace does not exist."), + ) + } + return err } + if namespace.DeletedAtM.Valid { return fault.New("namespace was deleted", fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), - fault.Internal("namespace not found"), fault.Public("This namespace does not exist."), + fault.Public("This namespace does not exist."), ) } // Verify permissions for rate limiting - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: namespace.ID, - Action: rbac.Limit, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: "*", - Action: rbac.Limit, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: namespace.ID, + Action: rbac.Limit, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: "*", + Action: rbac.Limit, + }), + ))) if err != nil { return err } - findOverrideMatchesArgs := db.ListRatelimitOverrideMatchesParams{ - WorkspaceID: auth.AuthorizedWorkspaceID, - NamespaceID: namespace.ID, - Identifier: req.Identifier, - } - ctx, overridesSpan := tracing.Start(ctx, "ListRatelimitOverrideMatches") - overrides, err := h.RatelimitOverrideMatchesCache.SWR(ctx, findOverrideMatchesArgs, func(ctx context.Context) ([]db.RatelimitOverride, error) { - return db.Query.ListRatelimitOverrideMatches(ctx, h.DB.RO(), findOverrideMatchesArgs) - }, func(err error) cache.Op { - if err == nil { - // everything went well and we have a namespace response - return cache.WriteValue - } - // this is a noop in the cache - return cache.Noop - }) - - overridesSpan.End() - if db.IsNotFound(err) { - return fault.New("namespace was deleted", - fault.Code(codes.Data.RatelimitNamespace.NotFound.URN()), - fault.Internal("namespace not found"), fault.Public("This namespace does not exist."), - ) - } - if err != nil { - return err - } // Determine limit and duration from override or request var ( limit = req.Limit duration = req.Duration overrideId = "" ) - for _, override := range overrides { - if override.DeletedAtM.Valid { - continue - } - limit = int64(override.Limit) - duration = int64(override.Duration) - overrideId = override.ID - if override.Identifier == req.Identifier { - // Exact match takes precedence - break - } + override, found, err := matchOverride(req.Identifier, namespace) + if err != nil { + return fault.Wrap(err, + fault.Code(codes.App.Internal.UnexpectedError.URN()), + fault.Internal("error matching overrides"), fault.Public("Error matching ratelimit override"), + ) + } + + if found { + limit = override.Limit + duration = override.Duration + overrideId = override.ID } // Apply rate limit @@ -201,7 +191,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { result, err := h.Ratelimit.Ratelimit(ctx, limitReq) if err != nil { return fault.Wrap(err, - fault.Internal("rate limit failed"), fault.Public("We're unable to process the rate limit request."), + fault.Internal("rate limit failed"), + fault.Public("We're unable to process the rate limit request."), ) } @@ -227,9 +218,32 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { OverrideId: nil, }, } + if overrideId != "" { res.Data.OverrideId = &overrideId } + // Return success response return s.JSON(http.StatusOK, res) } + +func matchOverride(identifier string, namespace db.FindRatelimitNamespace) (db.FindRatelimitNamespaceLimitOverride, bool, error) { + if override, ok := namespace.DirectOverrides[identifier]; ok { + return override, true, nil + } + + for _, override := range namespace.WildcardOverrides { + ok, err := match.Wildcard(identifier, override.Identifier) + if err != nil { + return db.FindRatelimitNamespaceLimitOverride{}, false, err + } + + if !ok { + continue + } + + return override, true, nil + } + + return db.FindRatelimitNamespaceLimitOverride{}, false, nil +} diff --git a/go/apps/api/routes/v2_ratelimit_limit/simulation_test.gox b/go/apps/api/routes/v2_ratelimit_limit/simulation_test.gox deleted file mode 100644 index b805ce105e..0000000000 --- a/go/apps/api/routes/v2_ratelimit_limit/simulation_test.gox +++ /dev/null @@ -1,450 +0,0 @@ -package v2RatelimitLimit_test - -import ( - "context" - "database/sql" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" - handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_ratelimit_limit" - "github.com/unkeyed/unkey/go/internal/services/ratelimit" - "github.com/unkeyed/unkey/go/pkg/clock" - "github.com/unkeyed/unkey/go/pkg/cluster" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/sim" - "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" -) - -// RateLimitState represents the simulation state for rate limiting tests -type RateLimitState struct { - // Time tracking - Clock *clock.TestClock - - // Namespace information - NamespaceID string - NamespaceName string - - // Identifiers for tracking different users/resources - Identifiers []string - - // Default rate limit configuration - DefaultLimit int64 - DefaultDuration time.Duration - - // Overrides - Overrides map[string]Override - - // Request tracking - Requests map[string][]Request // Map of identifier -> requests - ExpectedRemaining map[string]int64 // Expected remaining counts - LastReset map[string]time.Time // When limits were last reset -} - -// Override represents a rate limit override for a specific identifier -type Override struct { - ID string - Limit int64 - Duration time.Duration -} - -// Request represents a rate limit request made during simulation -type Request struct { - Timestamp time.Time - Cost int64 - Success bool - Remaining int64 -} - -// createSimulation initializes a simulation for rate limiting tests -func createSimulation(t *testing.T, seed sim.Seed) *sim.Simulation[RateLimitState] { - t.Helper() - - return sim.New[RateLimitState](seed, - sim.WithState(func(rng *sim.Rand) *RateLimitState { - // Create a test clock starting at a fixed time - testClock := clock.NewTestClock(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)) - - // Generate random namespace details - namespaceID := uid.New("ns_test") - namespaceName := fmt.Sprintf("test_namespace_%s", uid.New("")) - - // Generate random identifiers (users/resources) - identifierCount := 5 + rng.IntN(10) // 5-14 identifiers - identifiers := make([]string, identifierCount) - for i := range identifiers { - identifiers[i] = fmt.Sprintf("user_%s", uid.New("")) - } - - // Generate random limit configuration (reasonable values) - defaultLimit := int64(10 + rng.IntN(490)) // 10-500 - defaultDuration := time.Duration(1+rng.IntN(60)) * time.Second // 1-60 seconds - - return &RateLimitState{ - Clock: testClock, - NamespaceID: namespaceID, - NamespaceName: namespaceName, - Identifiers: identifiers, - DefaultLimit: defaultLimit, - DefaultDuration: defaultDuration, - Overrides: make(map[string]Override), - Requests: make(map[string][]Request), - ExpectedRemaining: make(map[string]int64), - LastReset: make(map[string]time.Time), - } - }), - ) -} - -// validateRateLimitState verifies the simulation state is consistent -func validateRateLimitState(state *RateLimitState) error { - // Validate rate limits for each identifier - for _, identifier := range state.Identifiers { - // Skip if no requests for this identifier - if len(state.Requests[identifier]) == 0 { - continue - } - - // Get limit and duration for this identifier (from override or default) - var limit int64 - var duration time.Duration - - if override, exists := state.Overrides[identifier]; exists { - limit = override.Limit - duration = override.Duration - } else { - limit = state.DefaultLimit - duration = state.DefaultDuration - } - - // Check if we've exceeded our limit - if expected := state.ExpectedRemaining[identifier]; expected < 0 { - return fmt.Errorf("validation error: negative remaining count for %s: %d", - identifier, expected) - } - - // Verify the last successful request didn't exceed the limit - requests := state.Requests[identifier] - if len(requests) > 0 { - lastRequest := requests[len(requests)-1] - if lastRequest.Success && lastRequest.Remaining > limit { - return fmt.Errorf("validation error: remaining count %d exceeds limit %d for %s", - lastRequest.Remaining, limit, identifier) - } - } - } - - return nil -} - -func TestRateLimitSimulation(t *testing.T) { - // Skip this test unless explicitly enabled - sim.CheckEnabled(t) - - // Create simulation with random seed - seed := sim.NewSeed() - simulation := createSimulation(t, seed) - - // Add state validator - simulation = sim.WithValidator(validateRateLimitState)(simulation) - - // Define events (we'll implement these below) - events := []sim.Event[RateLimitState]{ - &CreateOverrideEvent{}, - &DeleteOverrideEvent{}, - &RateLimitRequestEvent{}, - &AdvanceTimeEvent{}, - } - - // Run the simulation - err := simulation.Run(events) - require.NoError(t, err, "Simulation failed with seed %s", seed.String()) - - // Additional assertions if needed - state := simulation.State() - require.NotNil(t, state) - - // Log summary statistics - var totalRequests int - var successfulRequests int - - for _, identifier := range state.Identifiers { - requests := state.Requests[identifier] - totalRequests += len(requests) - - for _, req := range requests { - if req.Success { - successfulRequests++ - } - } - } - - t.Logf("Simulation completed: %d/%d successful requests across %d identifiers", - successfulRequests, totalRequests, len(state.Identifiers)) -} - -// CreateOverrideEvent creates a rate limit override for a random identifier -type CreateOverrideEvent struct{} - -func (e CreateOverrideEvent) Name() string { - return "CreateOverride" -} - -func (e CreateOverrideEvent) Run(rng *sim.Rand, state *RateLimitState) error { - // Randomly select an identifier to create an override for - if len(state.Identifiers) == 0 { - return nil // No identifiers to create overrides for - } - - identifierIndex := rng.IntN(len(state.Identifiers)) - identifier := state.Identifiers[identifierIndex] - - // Generate random override values - limit := int64(5 + rng.IntN(995)) // 5-1000 - duration := time.Duration(rng.IntN(120)+1) * time.Second // 1-120 seconds - - // Create the override - overrideID := uid.New("ovr_test") - state.Overrides[identifier] = Override{ - ID: overrideID, - Limit: limit, - Duration: duration, - } - - // Reset the expected remaining count to the new limit - state.ExpectedRemaining[identifier] = limit - state.LastReset[identifier] = state.Clock.Now() - - return nil -} - -// DeleteOverrideEvent removes a rate limit override -type DeleteOverrideEvent struct{} - -func (e DeleteOverrideEvent) Name() string { - return "DeleteOverride" -} - -func (e DeleteOverrideEvent) Run(rng *sim.Rand, state *RateLimitState) error { - // Find identifiers with overrides - var identifiersWithOverrides []string - for _, id := range state.Identifiers { - if _, exists := state.Overrides[id]; exists { - identifiersWithOverrides = append(identifiersWithOverrides, id) - } - } - - // If no overrides exist, do nothing - if len(identifiersWithOverrides) == 0 { - return nil - } - - // Select a random override to delete - identifierIndex := rng.IntN(len(identifiersWithOverrides)) - identifier := identifiersWithOverrides[identifierIndex] - - // Delete the override - delete(state.Overrides, identifier) - - // Reset the expected remaining count to the default limit - state.ExpectedRemaining[identifier] = state.DefaultLimit - state.LastReset[identifier] = state.Clock.Now() - - return nil -} - -// RateLimitRequestEvent simulates making a rate limit request -type RateLimitRequestEvent struct{} - -func (e RateLimitRequestEvent) Name() string { - return "RateLimitRequest" -} - -func (e RateLimitRequestEvent) Run(rng *sim.Rand, state *RateLimitState) error { - // Select a random identifier - if len(state.Identifiers) == 0 { - return nil // No identifiers to make requests for - } - - identifierIndex := rng.IntN(len(state.Identifiers)) - identifier := state.Identifiers[identifierIndex] - - // Generate a random cost with weighted distribution - costRoll := rng.Float64() - var cost int64 - - switch { - case costRoll < 0.10: // 10% chance of cost=0 - cost = 0 - case costRoll < 0.80: // 70% chance of cost=1 - cost = 1 - case costRoll < 0.95: // 15% chance of cost between 2-5 - cost = int64(2 + rng.IntN(4)) - default: // 5% chance of cost between 6-20 - cost = int64(6 + rng.IntN(15)) - } - - // Get the appropriate limit and duration - var limit int64 - var duration time.Duration - - if override, exists := state.Overrides[identifier]; exists { - limit = override.Limit - duration = override.Duration - } else { - limit = state.DefaultLimit - duration = state.DefaultDuration - } - - // Initialize expected remaining if not already set - if _, exists := state.ExpectedRemaining[identifier]; !exists { - state.ExpectedRemaining[identifier] = limit - state.LastReset[identifier] = state.Clock.Now() - } - - // Check if window has reset - now := state.Clock.Now() - lastReset := state.LastReset[identifier] - if now.Sub(lastReset) >= duration { - // Reset window - state.ExpectedRemaining[identifier] = limit - state.LastReset[identifier] = now - } - - // Calculate if the request should succeed and the remaining count - success := true - expectedRemaining := state.ExpectedRemaining[identifier] - - // Only reduce remaining if cost > 0 - if cost > 0 { - if expectedRemaining < cost { - success = false - } else { - expectedRemaining -= cost - } - } - - // Update the expected remaining count if the request was successful - if success && cost > 0 { - state.ExpectedRemaining[identifier] = expectedRemaining - } - - // Create a ratelimit request - request := handler.Request{ - Identifier: identifier, - Duration: duration, - Limit: limit, - Cost: cost, - } - - // Process the rate limit (in a real implementation, this would call the service) - response := handler.Response{ - Success: success, - Remaining: expectedRemaining, - Reset: state.LastReset[identifier].Add(duration).Unix(), - } - - if _, exists := state.Requests[identifier]; !exists { - state.Requests[identifier] = []Request{} - } - state.Requests[identifier] = append(state.Requests[identifier], request) - - return nil -} - -// AdvanceTimeEvent advances the simulation clock -type AdvanceTimeEvent struct{} - -func (e AdvanceTimeEvent) Name() string { - return "AdvanceTime" -} - -func (e AdvanceTimeEvent) Run(rng *sim.Rand, state *RateLimitState) error { - // Determine how much time to advance - // Sometimes advance a small amount, sometimes jump ahead significantly - var advanceAmount time.Duration - - timeRoll := rng.Float64() - switch { - case timeRoll < 0.70: // 70% small advancement (under 1s) - advanceAmount = time.Duration(rng.IntN(1000)) * time.Millisecond - case timeRoll < 0.90: // 20% medium advancement (1-10s) - advanceAmount = time.Duration(1+rng.IntN(9)) * time.Second - default: // 10% large advancement (10s-2min) - advanceAmount = time.Duration(10+rng.IntN(110)) * time.Second - } - - // Advance the clock - state.Clock.Tick(advanceAmount) - - return nil -} - -// setupRateLimitService creates a test rate limit service using the test clock -func setupRateLimitService(t *testing.T, testClock *clock.TestClock) ratelimit.Service { - t.Helper() - - // Create a logger - logger := logging.New(logging.Config{ - Development: true, - NoColor: true, - }) - - // Create a cluster (use noop for testing) - c := cluster.NewNoop("test_node", "localhost") - - // Create the service using the test clock - svc, err := ratelimit.New(ratelimit.Config{ - Logger: logger, - Cluster: c, - Clock: testClock, - }) - - if err != nil { - t.Fatalf("Failed to create rate limit service: %v", err) - } - - return svc -} - -// setupTestNamespace creates a test namespace in the database -func setupTestNamespace(t *testing.T, h *testutil.Harness, state *RateLimitState) error { - t.Helper() - - // Insert the namespace - return db.Query.InsertRatelimitNamespace(context.Background(), h.DB.RW(), db.InsertRatelimitNamespaceParams{ - ID: state.NamespaceID, - WorkspaceID: h.Resources.UserWorkspace.ID, - Name: state.NamespaceName, - CreatedAt: state.Clock.Now().UnixMilli(), - }) -} - -// setupTestOverride creates a test override in the database -func setupTestOverride(t *testing.T, h *testutil.Harness, state *RateLimitState, identifier string, override Override) error { - t.Helper() - - // Insert the override - return db.Query.InsertRatelimitOverride(context.Background(), h.DB.RW(), db.InsertRatelimitOverrideParams{ - ID: override.ID, - WorkspaceID: h.Resources.UserWorkspace.ID, - NamespaceID: state.NamespaceID, - Identifier: identifier, - Limit: int32(override.Limit), - Duration: int32(override.Duration.Milliseconds()), - CreatedAt: state.Clock.Now().UnixMilli(), - }) -} - -// deleteTestOverride soft-deletes an override from the database -func deleteTestOverride(t *testing.T, h *testutil.Harness, state *RateLimitState, overrideID string) error { - t.Helper() - - return db.Query.SoftDeleteRatelimitOverride(context.Background(), h.DB.RW(), db.SoftDeleteRatelimitOverrideParams{ - ID: overrideID, - Now: sql.NullInt64{Valid: true, Int64: state.Clock.Now().UnixMilli()}, - }) -} diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/200_test.go b/go/apps/api/routes/v2_ratelimit_list_overrides/200_test.go index e7b83f0d24..6546aa2d09 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/200_test.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/200_test.go @@ -47,10 +47,9 @@ func TestListOverridesSuccessfully(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/400_test.go b/go/apps/api/routes/v2_ratelimit_list_overrides/400_test.go index c2e436c8d3..667b8593fd 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/400_test.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/400_test.go @@ -17,10 +17,9 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/401_test.go b/go/apps/api/routes/v2_ratelimit_list_overrides/401_test.go index 6e5b65ea83..9e4720c312 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/401_test.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/401_test.go @@ -13,10 +13,9 @@ func TestUnauthorizedAccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/403_test.go b/go/apps/api/routes/v2_ratelimit_list_overrides/403_test.go index 6b8eb8acc5..17577434da 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/403_test.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/403_test.go @@ -45,10 +45,9 @@ func TestWorkspacePermissions(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/404_test.go b/go/apps/api/routes/v2_ratelimit_list_overrides/404_test.go index dcc080e0dc..70aed9829e 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/404_test.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/404_test.go @@ -31,10 +31,9 @@ func TestOverrideNotFound(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_list_overrides/handler.go b/go/apps/api/routes/v2_ratelimit_list_overrides/handler.go index e01a4a925f..a8337c94d6 100644 --- a/go/apps/api/routes/v2_ratelimit_list_overrides/handler.go +++ b/go/apps/api/routes/v2_ratelimit_list_overrides/handler.go @@ -6,7 +6,6 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -21,10 +20,9 @@ type Response = openapi.V2RatelimitListOverridesResponseBody // Handler implements zen.Route interface for the v2 ratelimit list overrides endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService + Logger logging.Logger + DB db.Database + Keys keys.KeyService } // Method returns the HTTP method this route responds to @@ -39,7 +37,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -72,22 +70,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: namespace.ID, - Action: rbac.ReadOverride, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: "*", - Action: rbac.ReadOverride, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: namespace.ID, + Action: rbac.ReadOverride, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: "*", + Action: rbac.ReadOverride, + }), + ))) if err != nil { return err } diff --git a/go/apps/api/routes/v2_ratelimit_set_override/200_test.go b/go/apps/api/routes/v2_ratelimit_set_override/200_test.go index 86b5cf92fc..2b3553a657 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/200_test.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/200_test.go @@ -30,11 +30,11 @@ func TestSetOverrideSuccessfully(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_set_override/400_test.go b/go/apps/api/routes/v2_ratelimit_set_override/400_test.go index c3bba7bf5b..30b6f13c8a 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/400_test.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/400_test.go @@ -18,10 +18,10 @@ func TestBadRequests(t *testing.T) { rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "ratelimit.*.set_override") route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_set_override/401_test.go b/go/apps/api/routes/v2_ratelimit_set_override/401_test.go index bab627bda1..32109244eb 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/401_test.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/401_test.go @@ -14,11 +14,11 @@ func TestUnauthorizedAccess(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_set_override/403_test.go b/go/apps/api/routes/v2_ratelimit_set_override/403_test.go index dff00fcb63..8018a6126a 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/403_test.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/403_test.go @@ -31,11 +31,11 @@ func TestWorkspacePermissions(t *testing.T) { require.NoError(t, err) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_set_override/404_test.go b/go/apps/api/routes/v2_ratelimit_set_override/404_test.go index 4126628e4c..3eac1a6715 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/404_test.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/404_test.go @@ -15,11 +15,11 @@ func TestNamespaceNotFound(t *testing.T) { h := testutil.NewHarness(t) route := &handler.Handler{ - DB: h.DB, - Keys: h.Keys, - Logger: h.Logger, - Permissions: h.Permissions, - Auditlogs: h.Auditlogs, + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + Auditlogs: h.Auditlogs, + RatelimitNamespaceByNameCache: h.Caches.RatelimitNamespaceByName, } h.Register(route) diff --git a/go/apps/api/routes/v2_ratelimit_set_override/handler.go b/go/apps/api/routes/v2_ratelimit_set_override/handler.go index 99a9f79f50..09911f28bb 100644 --- a/go/apps/api/routes/v2_ratelimit_set_override/handler.go +++ b/go/apps/api/routes/v2_ratelimit_set_override/handler.go @@ -11,8 +11,8 @@ import ( "github.com/unkeyed/unkey/go/apps/api/openapi" "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/pkg/auditlog" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/codes" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/fault" @@ -28,11 +28,11 @@ type Response = openapi.V2RatelimitSetOverrideResponseBody // Handler implements zen.Route interface for the v2 ratelimit set override endpoint type Handler struct { // Services as public fields - Logger logging.Logger - DB db.Database - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService + Logger logging.Logger + DB db.Database + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + RatelimitNamespaceByNameCache cache.Cache[string, db.FindRatelimitNamespace] } // Method returns the HTTP method this route responds to @@ -47,7 +47,7 @@ func (h *Handler) Path() string { // Handle processes the HTTP request func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { - auth, err := h.Keys.VerifyRootKey(ctx, s) + auth, err := h.Keys.GetRootKey(ctx, s) if err != nil { return err } @@ -80,22 +80,18 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - err = h.Permissions.Check( - ctx, - auth.KeyID, - rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: namespace.ID, - Action: rbac.SetOverride, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Ratelimit, - ResourceID: "*", - Action: rbac.SetOverride, - }), - ), - ) + err = auth.Verify(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: namespace.ID, + Action: rbac.SetOverride, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Ratelimit, + ResourceID: "*", + Action: rbac.SetOverride, + }), + ))) if err != nil { return fault.Wrap(err, fault.Internal("unable to check permissions"), fault.Public("We're unable to check the permissions of your key."), @@ -124,7 +120,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { { WorkspaceID: auth.AuthorizedWorkspaceID, Event: auditlog.RatelimitSetOverrideEvent, - ActorID: auth.KeyID, + ActorID: auth.Key.ID, ActorType: auditlog.RootKeyActor, ActorName: "root key", ActorMeta: map[string]any{}, @@ -146,6 +142,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return "", err } + h.RatelimitNamespaceByNameCache.Remove(ctx, namespace.Name) + return overrideID, nil }) if err != nil { diff --git a/go/apps/api/run.go b/go/apps/api/run.go index 5a2f4419d7..5169007fdf 100644 --- a/go/apps/api/run.go +++ b/go/apps/api/run.go @@ -11,7 +11,6 @@ import ( "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/caches" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/clock" @@ -20,6 +19,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/otel" "github.com/unkeyed/unkey/go/pkg/otel/logging" "github.com/unkeyed/unkey/go/pkg/prometheus" + "github.com/unkeyed/unkey/go/pkg/rbac" "github.com/unkeyed/unkey/go/pkg/shutdown" "github.com/unkeyed/unkey/go/pkg/vault" "github.com/unkeyed/unkey/go/pkg/vault/storage" @@ -59,12 +59,15 @@ func Run(ctx context.Context, cfg Config) error { if cfg.InstanceID != "" { logger = logger.With(slog.String("instanceID", cfg.InstanceID)) } + if cfg.Platform != "" { logger = logger.With(slog.String("platform", cfg.Platform)) } + if cfg.Region != "" { logger = logger.With(slog.String("region", cfg.Region)) } + if version.Version != "" { logger = logger.With(slog.String("version", version.Version)) } @@ -142,8 +145,8 @@ func Run(ctx context.Context, cfg Config) error { }) if err != nil { return fmt.Errorf("unable to create server: %w", err) - } + shutdowns.RegisterCtx(srv.Shutdown) validator, err := validation.New() @@ -151,17 +154,6 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("unable to create validator: %w", err) } - keySvc, err := keys.New(keys.Config{ - Logger: logger, - DB: db, - Clock: clk, - KeyCache: caches.KeyByHash, - WorkspaceCache: caches.WorkspaceByID, - }) - if err != nil { - return fmt.Errorf("unable to create key service: %w", err) - } - ctr, err := counter.NewRedis(counter.RedisConfig{ RedisURL: cfg.RedisUrl, Logger: logger, @@ -179,46 +171,56 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("unable to create ratelimit service: %w", err) } - p, err := permissions.New(permissions.Config{ - DB: db, - Logger: logger, - Clock: clk, - Cache: caches.PermissionsByKeyId, + keySvc, err := keys.New(keys.Config{ + Logger: logger, + DB: db, + KeyCache: caches.VerificationKeyByHash, + RateLimiter: rlSvc, + RBAC: rbac.New(), + Clickhouse: ch, }) if err != nil { - return fmt.Errorf("unable to create permissions service: %w", err) + return fmt.Errorf("unable to create key service: %w", err) } - vaultStorage, err := storage.NewMemory(storage.MemoryConfig{ - Logger: logger, - }) - if err != nil { - return fmt.Errorf("unable to create vault storage: %w", err) + var vaultSvc *vault.Service + if len(cfg.VaultMasterKeys) > 0 && cfg.VaultS3 != nil { + vaultStorage, err := storage.NewS3(storage.S3Config{ + Logger: logger, + S3URL: cfg.VaultS3.URL, + S3Bucket: cfg.VaultS3.Bucket, + S3AccessKeyID: cfg.VaultS3.AccessKeyID, + S3AccessKeySecret: cfg.VaultS3.SecretAccessKey, + }) + if err != nil { + return fmt.Errorf("unable to create vault storage: %w", err) + } + + vaultSvc, err = vault.New(vault.Config{ + Logger: logger, + Storage: vaultStorage, + MasterKeys: cfg.VaultMasterKeys, + }) + if err != nil { + return fmt.Errorf("unable to create vault service: %w", err) + } } - vaultSvc, err := vault.New(vault.Config{ - Logger: logger, - Storage: vaultStorage, - MasterKeys: cfg.VaultMasterKeys, + auditlogSvc := auditlogs.New(auditlogs.Config{ + Logger: logger, + DB: db, }) - if err != nil { - return fmt.Errorf("unable to create vault service: %w", err) - } routes.Register(srv, &routes.Services{ - Logger: logger, - Database: db, - ClickHouse: ch, - Keys: keySvc, - Validator: validator, - Ratelimit: rlSvc, - Permissions: p, - Auditlogs: auditlogs.New(auditlogs.Config{ - Logger: logger, - DB: db, - }), - Caches: caches, - Vault: vaultSvc, + Logger: logger, + Database: db, + ClickHouse: ch, + Keys: keySvc, + Validator: validator, + Ratelimit: rlSvc, + Auditlogs: auditlogSvc, + Caches: caches, + Vault: vaultSvc, }) go func() { diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index bb9f22d575..160530d427 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -132,6 +132,8 @@ var Cmd = &cli.Command{ Value: 0, Required: false, }, + + // Vault Configuration &cli.StringSliceFlag{ Name: "vault-master-keys", Usage: "Vault master keys for encryption", @@ -139,6 +141,36 @@ var Cmd = &cli.Command{ Value: []string{}, Required: false, }, + + // S3 Configuration + &cli.StringFlag{ + Name: "vault-s3-url", + Usage: "S3 Compatible Endpoint URL ", + Sources: cli.EnvVars("UNKEY_VAULT_S3_URL"), + Value: "", + Required: false, + }, + &cli.StringFlag{ + Name: "vault-s3-bucket", + Usage: "S3 bucket name", + Sources: cli.EnvVars("UNKEY_VAULT_S3_BUCKET"), + Value: "", + Required: false, + }, + &cli.StringFlag{ + Name: "vault-s3-access-key-id", + Usage: "S3 access key ID", + Sources: cli.EnvVars("UNKEY_VAULT_S3_ACCESS_KEY_ID"), + Value: "", + Required: false, + }, + &cli.StringFlag{ + Name: "vault-s3-secret-access-key", + Usage: "S3 secret access key", + Sources: cli.EnvVars("UNKEY_VAULT_S3_SECRET_ACCESS_KEY"), + Value: "", + Required: false, + }, }, Action: action, @@ -162,6 +194,16 @@ func action(ctx context.Context, cmd *cli.Command) error { } } + var vaultS3Config *api.S3Config + if cmd.String("vault-s3-url") != "" { + vaultS3Config = &api.S3Config{ + URL: cmd.String("vault-s3-url"), + Bucket: cmd.String("vault-s3-bucket"), + AccessKeyID: cmd.String("vault-s3-access-key-id"), + SecretAccessKey: cmd.String("vault-s3-secret-access-key"), + } + } + config := api.Config{ // Basic configuration Platform: cmd.String("platform"), @@ -189,7 +231,9 @@ func action(ctx context.Context, cmd *cli.Command) error { Clock: clock.New(), TestMode: cmd.Bool("test-mode"), + // Vault configuration VaultMasterKeys: cmd.StringSlice("vault-master-keys"), + VaultS3: vaultS3Config, } err := config.Validate() diff --git a/go/internal/services/caches/caches.go b/go/internal/services/caches/caches.go index b6f9e8f371..eb2c96447d 100644 --- a/go/internal/services/caches/caches.go +++ b/go/internal/services/caches/caches.go @@ -15,27 +15,15 @@ import ( type Caches struct { // RatelimitNamespaceByName caches ratelimit namespace lookups by name. // Keys are db.FindRatelimitNamespaceByNameParams and values are db.RatelimitNamespace. - RatelimitNamespaceByName cache.Cache[db.FindRatelimitNamespaceByNameParams, db.RatelimitNamespace] + RatelimitNamespaceByName cache.Cache[string, db.FindRatelimitNamespace] - // RatelimitOverridesMatch caches ratelimit override matches for specific criteria. - // Keys are db.ListRatelimitOverrideMatchesParams and values are slices of db.RatelimitOverride. - RatelimitOverridesMatch cache.Cache[db.ListRatelimitOverrideMatchesParams, []db.RatelimitOverride] - - // KeyByHash caches API key lookups by their hash. - // Keys are string (hash) and values are db.Key. - KeyByHash cache.Cache[string, db.Key] - - // PermissionsByKeyId caches permission strings for a given key ID. - // Keys are string (key ID) and values are slices of string representing permissions. - PermissionsByKeyId cache.Cache[string, []string] - - // WorkspaceByID caches workspace lookups by their ID. - // Keys are string (workspace ID) and values are db.Workspace. - WorkspaceByID cache.Cache[string, db.Workspace] + // VerificationKeyByHash caches verification key lookups by their hash. + // Keys are string (hash) and values are db.VerificationKey. + VerificationKeyByHash cache.Cache[string, db.FindKeyForVerificationRow] + // ApiByID caches API lookups by their ID. + // Keys are string (ID) and values are db.Api. ApiByID cache.Cache[string, db.Api] - - IdentityByID cache.Cache[string, db.Identity] } // Config defines the configuration options for initializing caches. @@ -77,64 +65,24 @@ type Config struct { // // Use the caches // key, err := caches.KeyByHash.Get(ctx, "some-hash") func New(config Config) (Caches, error) { - - ratelimitNamespace, err := cache.New(cache.Config[db.FindRatelimitNamespaceByNameParams, db.RatelimitNamespace]{ + ratelimitNamespace, err := cache.New(cache.Config[string, db.FindRatelimitNamespace]{ Fresh: time.Minute, Stale: 24 * time.Hour, Logger: config.Logger, MaxSize: 1_000_000, - Resource: "ratelimit_namespace_by_name", + Resource: "ratelimit_namespace", Clock: config.Clock, }) if err != nil { return Caches{}, err } - ratelimitOverridesMatch, err := cache.New(cache.Config[db.ListRatelimitOverrideMatchesParams, []db.RatelimitOverride]{ - Fresh: time.Minute, + verificationKeyByHash, err := cache.New(cache.Config[string, db.FindKeyForVerificationRow]{ + Fresh: 30 * time.Second, Stale: 24 * time.Hour, Logger: config.Logger, MaxSize: 1_000_000, - Resource: "ratelimit_overrides", - Clock: config.Clock, - }) - if err != nil { - return Caches{}, err - } - - keyByHash, err := cache.New(cache.Config[string, db.Key]{ - Fresh: 10 * time.Second, - Stale: 24 * time.Hour, - Logger: config.Logger, - MaxSize: 1_000_000, - - Resource: "key_by_hash", - Clock: config.Clock, - }) - if err != nil { - return Caches{}, err - } - - permissionsByKeyId, err := cache.New(cache.Config[string, []string]{ - Fresh: 10 * time.Second, - Stale: 24 * time.Hour, - Logger: config.Logger, - MaxSize: 1_000_000, - - Resource: "permissions_by_key_id", - Clock: config.Clock, - }) - if err != nil { - return Caches{}, err - } - - workspaceByID, err := cache.New(cache.Config[string, db.Workspace]{ - Fresh: 10 * time.Second, - Stale: 24 * time.Hour, - Logger: config.Logger, - MaxSize: 1_000_000, - - Resource: "workspace_by_id", + Resource: "verification_key_by_hash", Clock: config.Clock, }) if err != nil { @@ -154,26 +102,9 @@ func New(config Config) (Caches, error) { return Caches{}, err } - identityByID, err := cache.New(cache.Config[string, db.Identity]{ - Fresh: 10 * time.Second, - Stale: 24 * time.Hour, - Logger: config.Logger, - MaxSize: 1_000_000, - - Resource: "identity_by_id", - Clock: config.Clock, - }) - if err != nil { - return Caches{}, err - } - return Caches{ RatelimitNamespaceByName: middleware.WithTracing(ratelimitNamespace), - RatelimitOverridesMatch: middleware.WithTracing(ratelimitOverridesMatch), - KeyByHash: middleware.WithTracing(keyByHash), - PermissionsByKeyId: middleware.WithTracing(permissionsByKeyId), - WorkspaceByID: middleware.WithTracing(workspaceByID), ApiByID: middleware.WithTracing(apiById), - IdentityByID: middleware.WithTracing(identityByID), + VerificationKeyByHash: middleware.WithTracing(verificationKeyByHash), }, nil } diff --git a/go/internal/services/caches/op.go b/go/internal/services/caches/op.go index f379580570..14427ca374 100644 --- a/go/internal/services/caches/op.go +++ b/go/internal/services/caches/op.go @@ -13,11 +13,12 @@ func DefaultFindFirstOp(err error) cache.Op { // everything went well and we have a row response return cache.WriteValue } + if errors.Is(err, sql.ErrNoRows) { // the response is empty, we need to store that the row does not exist return cache.WriteNull } + // this is a noop in the cache return cache.Noop - } diff --git a/go/internal/services/keys/doc.go b/go/internal/services/keys/doc.go new file mode 100644 index 0000000000..938d36c6f1 --- /dev/null +++ b/go/internal/services/keys/doc.go @@ -0,0 +1,108 @@ +/* +Package keys implements a comprehensive key management and verification system for API keys +with support for rate limiting, usage tracking, permissions, and workspace isolation. + +# Architecture + +The keys service provides a unified interface for managing API keys throughout their lifecycle: + + - Key Creation: Secure generation of API keys with customizable prefixes and byte lengths + - Key Verification: Multi-stage validation with configurable options for different use cases + - Key Retrieval: Cached access to key metadata and authorization information + - Root Key Management: Special handling for workspace-level administrative keys + +# Key Verification System + +The verification system uses a flexible, option-based approach that supports: + + 1. Basic validation: existence, enabled status, expiration + 2. Usage limiting: credit-based consumption tracking + 3. Rate limiting: configurable time-window based limits + 4. Permission checking: RBAC-based authorization + 5. IP whitelisting: network-level access control + 6. Workspace isolation: multi-tenant security boundaries + +# Usage + +To create a new keys service: + + svc, err := keys.New(keys.Config{ + Logger: logger, + DB: database, + RateLimiter: rateLimiter, + UsageLimiter: usageLimiter, + RBAC: rbac, + Clickhouse: clickhouse, + KeyCache: keyCache, + WorkspaceCache: workspaceCache, + }) + +To verify a key with rate limiting and permissions: + + key, err := svc.Get(ctx, session, rawKey) + if err != nil { + return err + } + + err = key.Verify(ctx, + keys.WithCredits(1), + keys.WithPermissions(rbac.PermissionQuery{ + Action: "read", + Resource: "api.key", + }), + keys.WithRateLimits([]openapi.KeysVerifyKeyRatelimit{ + {Name: "requests", Limit: ptr.Int32(100), Duration: ptr.Int64(60000)}, + }), + ) + + if !key.Valid { + // Handle validation failure based on key.Status + } + +# Key Statuses + +The system defines comprehensive status codes for different validation outcomes: + + - VALID: Key passed all validation checks + - NOT_FOUND: Key does not exist in the system + - DISABLED: Key exists but is disabled + - EXPIRED: Key has passed its expiration time + - FORBIDDEN: Access denied (IP whitelist, etc.) + - INSUFFICIENT_PERMISSIONS: RBAC validation failed + - RATE_LIMITED: Rate limit exceeded + - USAGE_EXCEEDED: Usage credit limit exceeded + - WORKSPACE_DISABLED: Associated workspace is disabled + - WORKSPACE_NOT_FOUND: Associated workspace does not exist + +# Root Key Handling + +Root keys receive special treatment with automatic fault error conversion: +- Validation failures are immediately converted to fault errors +- Used for workspace-level administrative operations +- Identified by the presence of ForWorkspaceID in the key metadata + +# Thread Safety + +The service is designed to be thread-safe and can handle concurrent requests across +multiple goroutines. All cache operations and state modifications are properly synchronized. + +# Error Handling + +The service provides structured error handling with: +- Fault errors for client-facing validation failures +- System errors for internal service problems +- Comprehensive error codes matching the OpenAPI specification +- Detailed logging for debugging and monitoring + +# Performance Considerations + +The service includes several performance optimizations: +- Multi-level caching (key cache, workspace cache) +- Stale-while-revalidate cache patterns +- Batched telemetry data collection +- Efficient database queries with proper indexing + +See the KeyService interface and KeyVerifier struct for detailed documentation +of the API contract and available methods. +*/ +package keys diff --git a/go/internal/services/keys/get.go b/go/internal/services/keys/get.go new file mode 100644 index 0000000000..06dcb0e5f9 --- /dev/null +++ b/go/internal/services/keys/get.go @@ -0,0 +1,177 @@ +package keys + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/unkeyed/unkey/go/internal/services/caches" + "github.com/unkeyed/unkey/go/pkg/assert" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/otel/tracing" + "github.com/unkeyed/unkey/go/pkg/zen" +) + +// GetRootKey retrieves and validates a root key from the session's Authorization header. +// Root keys are special administrative keys that can access workspace-level operations. +// Validation failures are immediately converted to fault errors for root keys. +func (s *service) GetRootKey(ctx context.Context, sess *zen.Session) (*KeyVerifier, error) { + ctx, span := tracing.Start(ctx, "keys.GetRootKey") + defer span.End() + + rootKey, err := zen.Bearer(sess) + if err != nil { + return nil, fault.Wrap(err, + fault.Internal("no bearer"), + fault.Public("You must provide a valid root key in the Authorization header in the format 'Bearer ROOT_KEY'."), + ) + } + + key, err := s.Get(ctx, sess, rootKey) + if err != nil { + return nil, err + } + + // For root keys, convert validation failures to proper fault errors immediately + if key.Status != StatusValid { + return nil, fault.Wrap( + key.ToFault(), + fault.Internal("invalid root key"), + fault.Public("The provided root key is invalid."), + ) + } + + return key, nil +} + +// Get retrieves a key from the database and performs basic validation checks. +// It returns a KeyVerifier that can be used for further validation with specific options. +// For normal keys, validation failures are indicated by KeyVerifier.Valid=false. +func (s *service) Get(ctx context.Context, sess *zen.Session, rawKey string) (*KeyVerifier, error) { + ctx, span := tracing.Start(ctx, "keys.Get") + defer span.End() + + err := assert.NotEmpty(rawKey) + if err != nil { + return nil, fault.Wrap(err, fault.Internal("rawKey is empty")) + } + + h := hash.Sha256(rawKey) + key, err := s.keyCache.SWR(ctx, h, func(ctx context.Context) (db.FindKeyForVerificationRow, error) { + return db.Query.FindKeyForVerification(ctx, s.db.RO(), h) + }, caches.DefaultFindFirstOp) + if err != nil { + if db.IsNotFound(err) { + // nolint:exhaustruct + return &KeyVerifier{ + Status: StatusNotFound, + message: "key does not exist", + }, nil + } + + return nil, fault.Wrap( + err, + fault.Internal("unable to load key"), + fault.Public("We could not load the requested key."), + ) + } + + // ForWorkspace set but that doesn't exist + if key.ForWorkspaceID.Valid && !key.ForWorkspaceEnabled.Valid { + // nolint:exhaustruct + return &KeyVerifier{ + Status: StatusWorkspaceNotFound, + message: "workspace not found", + }, nil + } + + if !key.WorkspaceEnabled || (key.ForWorkspaceEnabled.Valid && !key.ForWorkspaceEnabled.Bool) { + // nolint:exhaustruct + return &KeyVerifier{ + Status: StatusWorkspaceDisabled, + message: "workspace is disabled", + }, nil + } + + // The DB returns this in array format and an empty array if not found + var roles, permissions []string + var ratelimitArr []db.KeyFindForVerificationRatelimit + err = json.Unmarshal(key.Roles.([]byte), &roles) + if err != nil { + return nil, err + } + err = json.Unmarshal(key.Permissions.([]byte), &permissions) + if err != nil { + return nil, err + } + err = json.Unmarshal(key.Ratelimits.([]byte), &ratelimitArr) + if err != nil { + return nil, err + } + + // Convert rate limits array to map (key name -> config) + // Key rate limits take precedence over identity rate limits + ratelimitConfigs := make(map[string]db.KeyFindForVerificationRatelimit) + for _, rl := range ratelimitArr { + existing, exists := ratelimitConfigs[rl.Name] + if !exists { + ratelimitConfigs[rl.Name] = rl + continue + } + + if rl.KeyID != "" && existing.IdentityID != "" { + ratelimitConfigs[rl.Name] = rl + } + } + + authorizedWorkspaceID := key.WorkspaceID + if key.ForWorkspaceID.Valid { + authorizedWorkspaceID = key.ForWorkspaceID.String + } + + sess.WorkspaceID = authorizedWorkspaceID + kv := &KeyVerifier{ + Key: key, + clickhouse: s.clickhouse, + rateLimiter: s.raterLimiter, + usageLimiter: s.usageLimiter, + AuthorizedWorkspaceID: authorizedWorkspaceID, + rBAC: s.rbac, + session: sess, + logger: s.logger, + message: "", + isRootKey: key.ForWorkspaceID.Valid, + + // By default we assume the key is valid unless proven otherwise + Status: StatusValid, + ratelimitConfigs: ratelimitConfigs, + Roles: roles, + Permissions: permissions, + RatelimitResults: nil, + } + + if key.DeletedAtM.Valid { + kv.setInvalid(StatusNotFound, "key is deleted") + return kv, nil + } + + if key.ApiDeletedAtM.Valid { + kv.setInvalid(StatusNotFound, "key is deleted") + return kv, nil + } + + if !key.Enabled { + kv.setInvalid(StatusDisabled, "key is disabled") + return kv, nil + } + + if key.Expires.Valid && time.Now().After(key.Expires.Time) { + kv.setInvalid(StatusExpired, fmt.Sprintf("the key has expired on %s", key.Expires.Time.Format(time.RFC3339))) + return kv, nil + } + + return kv, nil +} diff --git a/go/internal/services/keys/interface.go b/go/internal/services/keys/interface.go index cf780e9faa..a1006af80c 100644 --- a/go/internal/services/keys/interface.go +++ b/go/internal/services/keys/interface.go @@ -6,25 +6,32 @@ import ( "github.com/unkeyed/unkey/go/pkg/zen" ) +// KeyService defines the interface for key management operations. +// It provides methods for key creation, retrieval, and validation. type KeyService interface { - Verify(ctx context.Context, hash string) (VerifyResponse, error) - VerifyRootKey(ctx context.Context, sess *zen.Session) (VerifyResponse, error) + // Get retrieves a key and returns a KeyVerifier for validation + Get(ctx context.Context, sess *zen.Session, hash string) (*KeyVerifier, error) + // GetRootKey retrieves and validates a root key from the session + GetRootKey(ctx context.Context, sess *zen.Session) (*KeyVerifier, error) + // CreateKey generates a new secure API key CreateKey(ctx context.Context, req CreateKeyRequest) (CreateKeyResponse, error) } +// VerifyResponse contains the result of a successful key verification. type VerifyResponse struct { - AuthorizedWorkspaceID string - KeyID string + AuthorizedWorkspaceID string // The workspace ID that the key is authorized for + KeyID string // The unique identifier of the key } +// CreateKeyRequest specifies the parameters for creating a new API key. type CreateKeyRequest struct { - // Key generation parameters - Prefix string - ByteLength int + Prefix string // Optional prefix to prepend to the key (e.g., "test_", "prod_") + ByteLength int // Length of the random bytes to generate (16-255) } +// CreateKeyResponse contains the generated key and its metadata. type CreateKeyResponse struct { - Key string // The plaintext key - Hash string // SHA-256 hash for storage - Start string // Key prefix for indexing + Key string // The complete plaintext key (prefix + encoded random bytes) + Hash string // SHA-256 hash of the key for secure storage + Start string // The start of the key for indexing and display purposes } diff --git a/go/internal/services/keys/options.go b/go/internal/services/keys/options.go new file mode 100644 index 0000000000..832caec0c0 --- /dev/null +++ b/go/internal/services/keys/options.go @@ -0,0 +1,76 @@ +package keys + +import ( + "errors" + + "github.com/unkeyed/unkey/go/apps/api/openapi" + "github.com/unkeyed/unkey/go/pkg/rbac" +) + +// VerifyOption represents a functional option for configuring key verification. +// Options can be combined to create complex validation scenarios. +type VerifyOption func(*verifyConfig) error + +// verifyConfig holds the internal configuration for verification options. +type verifyConfig struct { + ipWhitelist bool + credits *int32 + apiID *string + tags []string + permissions *rbac.PermissionQuery + ratelimits []openapi.KeysVerifyKeyRatelimit +} + +// WithCredits validates that the key has sufficient usage credits and deducts the specified cost. +// The cost must be non-negative. If the key doesn't have enough credits, verification fails. +func WithCredits(cost int32) VerifyOption { + return func(config *verifyConfig) error { + if cost < 0 { + return errors.New("cost cannot be negative") + } + config.credits = &cost + return nil + } +} + +// WithIPWhitelist validates that the client IP address is in the key's IP whitelist. +// The client IP is extracted from the session. If no whitelist is configured, this check is skipped. +func WithIPWhitelist() VerifyOption { + return func(config *verifyConfig) error { + config.ipWhitelist = true + return nil + } +} + +// WithPermissions validates that the key has the required RBAC permissions. +// The query specifies the action and resource that the key needs access to. +func WithPermissions(query rbac.PermissionQuery) VerifyOption { + return func(config *verifyConfig) error { + config.permissions = &query + return nil + } +} + +// WithRateLimits validates the key against the specified rate limits. +// These limits are applied in addition to any auto-applied limits on the key or identity. +func WithRateLimits(limits []openapi.KeysVerifyKeyRatelimit) VerifyOption { + return func(config *verifyConfig) error { + config.ratelimits = limits + return nil + } +} + +// WithTags adds given tags to the key verification. +func WithTags(tags []string) VerifyOption { + return func(config *verifyConfig) error { + config.tags = tags + return nil + } +} + +func WithApiID(apiID string) VerifyOption { + return func(config *verifyConfig) error { + config.apiID = &apiID + return nil + } +} diff --git a/go/internal/services/keys/service.go b/go/internal/services/keys/service.go index b2b27ead83..f9084f390b 100644 --- a/go/internal/services/keys/service.go +++ b/go/internal/services/keys/service.go @@ -1,33 +1,57 @@ package keys import ( + "fmt" + + "github.com/unkeyed/unkey/go/internal/services/ratelimit" + "github.com/unkeyed/unkey/go/internal/services/usagelimiter" "github.com/unkeyed/unkey/go/pkg/cache" - "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/rbac" ) +// Config holds the configuration for creating a new keys service instance. type Config struct { - Logger logging.Logger - DB db.Database - Clock clock.Clock - KeyCache cache.Cache[string, db.Key] - WorkspaceCache cache.Cache[string, db.Workspace] + Logger logging.Logger // Logger for service operations + DB db.Database // Database connection + RateLimiter ratelimit.Service // Rate limiting service + RBAC *rbac.RBAC // Role-based access control + Clickhouse clickhouse.ClickHouse // Clickhouse for telemetry + + KeyCache cache.Cache[string, db.FindKeyForVerificationRow] // Cache for key lookups } type service struct { - logger logging.Logger - db db.Database + logger logging.Logger + db db.Database + raterLimiter ratelimit.Service + usageLimiter usagelimiter.Service + rbac *rbac.RBAC + clickhouse clickhouse.ClickHouse + // hash -> key - keyCache cache.Cache[string, db.Key] - workspaceCache cache.Cache[string, db.Workspace] + keyCache cache.Cache[string, db.FindKeyForVerificationRow] } +// New creates a new keys service instance with the provided configuration. func New(config Config) (*service, error) { + ulSvc, err := usagelimiter.New(usagelimiter.Config{ + Logger: config.Logger, + DB: config.DB, + }) + if err != nil { + return nil, fmt.Errorf("unable to create usage limiter service: %w", err) + } + return &service{ - logger: config.Logger, - db: config.DB, - keyCache: config.KeyCache, - workspaceCache: config.WorkspaceCache, + logger: config.Logger, + db: config.DB, + rbac: config.RBAC, + raterLimiter: config.RateLimiter, + usageLimiter: ulSvc, + clickhouse: config.Clickhouse, + keyCache: config.KeyCache, }, nil } diff --git a/go/internal/services/keys/status.go b/go/internal/services/keys/status.go new file mode 100644 index 0000000000..e4daae67a8 --- /dev/null +++ b/go/internal/services/keys/status.go @@ -0,0 +1,155 @@ +package keys + +import ( + "github.com/unkeyed/unkey/go/apps/api/openapi" + "github.com/unkeyed/unkey/go/pkg/codes" + "github.com/unkeyed/unkey/go/pkg/fault" +) + +// KeyStatus represents the validation status of a key after verification. +// Each status indicates a specific validation outcome that can be used +// to determine the appropriate response and error handling. +type KeyStatus string + +const ( + StatusValid KeyStatus = "VALID" + StatusNotFound KeyStatus = "NOT_FOUND" + StatusDisabled KeyStatus = "DISABLED" + StatusExpired KeyStatus = "EXPIRED" + StatusForbidden KeyStatus = "FORBIDDEN" + StatusInsufficientPermissions KeyStatus = "INSUFFICIENT_PERMISSIONS" + StatusRateLimited KeyStatus = "RATE_LIMITED" + StatusUsageExceeded KeyStatus = "USAGE_EXCEEDED" + StatusWorkspaceDisabled KeyStatus = "WORKSPACE_DISABLED" + StatusWorkspaceNotFound KeyStatus = "WORKSPACE_NOT_FOUND" +) + +// ToFault converts the verification result to an appropriate fault error. +// This method should only be called when k.Valid is false. +// It provides structured error information that matches the API specification. +func (k *KeyVerifier) ToFault() error { + switch k.Status { + case StatusValid: + return nil + case StatusNotFound: + return fault.New("key does not exist", + fault.Code(codes.Auth.Authentication.KeyNotFound.URN()), + fault.Internal("key does not exist"), + fault.Public("We could not find the requested key."), + ) + case StatusDisabled: + message := k.message + if message == "" { + message = "the key is disabled" + } + return fault.New("key is disabled", + fault.Code(codes.Auth.Authorization.KeyDisabled.URN()), + fault.Internal(message), + fault.Public("The key is disabled."), + ) + case StatusExpired: + message := k.message + if message == "" { + message = "the key has expired" + } + return fault.New("key has expired", + fault.Code(codes.Auth.Authorization.Forbidden.URN()), + fault.Internal(message), + fault.Public(message), + ) + case StatusWorkspaceDisabled: + return fault.New("workspace is disabled", + fault.Code(codes.Auth.Authorization.WorkspaceDisabled.URN()), + fault.Internal("workspace disabled"), + fault.Public("The workspace is disabled."), + ) + case StatusWorkspaceNotFound: + return fault.New("workspace not found", + fault.Code(codes.Data.Workspace.NotFound.URN()), + fault.Internal("workspace disabled"), + fault.Public("The requested workspace does not exist."), + ) + case StatusForbidden: + message := k.message + if message == "" { + message = "Forbidden" + } + return fault.New("forbidden", + fault.Code(codes.Auth.Authorization.Forbidden.URN()), + fault.Internal(message), + fault.Public(message), + ) + case StatusInsufficientPermissions: + message := k.message + if message == "" { + message = "Insufficient permissions to access this resource." + } + return fault.New("insufficient permissions", + fault.Code(codes.Auth.Authorization.InsufficientPermissions.URN()), + fault.Internal(message), + fault.Public(message), + ) + case StatusUsageExceeded: + message := k.message + if message == "" { + message = "Key usage limit exceeded." + } + return fault.New("key usage limit exceeded", + fault.Code(codes.Auth.Authorization.Forbidden.URN()), + fault.Internal(message), + fault.Public(message), + ) + case StatusRateLimited: + message := k.message + if message == "" { + message = "Rate limit exceeded" + } + return fault.New("rate limit exceeded", + fault.Code(codes.Auth.Authorization.Forbidden.URN()), + fault.Internal(message), + fault.Public(message), + ) + default: + return fault.New("key verification failed", + fault.Code(codes.Auth.Authorization.Forbidden.URN()), + fault.Internal("key verification failed with unknown status"), + fault.Public("Key verification failed."), + ) + } +} + +// ToOpenAPIStatus converts our internal KeyStatus to the OpenAPI response status type. +// This mapping ensures consistency between internal validation and external API responses. +func (k *KeyVerifier) ToOpenAPIStatus() openapi.KeysVerifyKeyResponseDataCode { + switch k.Status { + case StatusValid: + return openapi.VALID + case StatusNotFound: + return openapi.NOTFOUND + case StatusDisabled: + return openapi.DISABLED + case StatusExpired: + return openapi.EXPIRED + case StatusForbidden: + return openapi.FORBIDDEN + case StatusInsufficientPermissions: + return openapi.INSUFFICIENTPERMISSIONS + case StatusUsageExceeded: + return openapi.USAGEEXCEEDED + case StatusRateLimited: + return openapi.RATELIMITED + case StatusWorkspaceNotFound: + return openapi.NOTFOUND + case StatusWorkspaceDisabled: + return openapi.FORBIDDEN + default: + return openapi.FORBIDDEN + } +} + +// setInvalid marks the key as invalid with the specified status and message. +// This is used internally by validation methods to indicate validation failures. +func (k *KeyVerifier) setInvalid(status KeyStatus, message string) { + k.Status = status + k.message = message +} diff --git a/go/internal/services/keys/validation.go b/go/internal/services/keys/validation.go new file mode 100644 index 0000000000..6142b51336 --- /dev/null +++ b/go/internal/services/keys/validation.go @@ -0,0 +1,242 @@ +package keys + +import ( + "context" + "database/sql" + "fmt" + "strconv" + "strings" + "time" + + "github.com/unkeyed/unkey/go/apps/api/openapi" + "github.com/unkeyed/unkey/go/internal/services/ratelimit" + "github.com/unkeyed/unkey/go/internal/services/usagelimiter" + "github.com/unkeyed/unkey/go/pkg/codes" + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/otel/tracing" + "github.com/unkeyed/unkey/go/pkg/prometheus/metrics" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/rbac" + "golang.org/x/exp/slices" +) + +// withCredits validates that the key has sufficient usage credits and deducts the specified cost. +// It updates the key's remaining request count and marks the key as invalid if the limit is exceeded. +func (k *KeyVerifier) withCredits(ctx context.Context, cost int32) error { + ctx, span := tracing.Start(ctx, "verify.withCredits") + defer span.End() + + if k.Status != StatusValid { + return nil + } + + usage, err := k.usageLimiter.Limit(ctx, usagelimiter.UsageRequest{ + KeyId: k.Key.ID, + Cost: cost, + }) + if err != nil { + return err + } + + k.Key.RemainingRequests = sql.NullInt32{Int32: usage.Remaining, Valid: usage.Remaining >= 0} + if !usage.Valid { + k.setInvalid(StatusUsageExceeded, "Key usage limit exceeded.") + } + + // Emit Prometheus metrics for credits spent + identityID := "" + if k.Key.IdentityID.Valid { + identityID = k.Key.IdentityID.String + } + + // Credits are deducted when usage is valid AND cost > 0 + deducted := usage.Valid && cost > 0 + actualCostDeducted := int32(0) + if deducted { + actualCostDeducted = cost + } + + metrics.KeyCreditsSpentTotal.WithLabelValues( + k.AuthorizedWorkspaceID, // workspace_id + k.Key.ID, // key_id + identityID, // identity_id + strconv.FormatBool(deducted), // deducted - whether credits were actually deducted + ).Add(float64(actualCostDeducted)) // Add the actual amount deducted, not the requested cost + + return nil +} + +// withIPWhitelist validates that the client IP address is in the key's IP whitelist. +// If no whitelist is configured, this validation is skipped. +func (k *KeyVerifier) withIPWhitelist() error { + if k.Status != StatusValid { + return nil + } + + if !k.Key.IpWhitelist.Valid { + return nil + } + + clientIP := k.session.Location() + if clientIP == "" { + k.Status = StatusForbidden + k.message = "client IP is required for IP whitelist validation" + return nil + } + + allowedIPs := strings.Split(k.Key.IpWhitelist.String, ",") + for i, ip := range allowedIPs { + allowedIPs[i] = strings.TrimSpace(ip) + } + + if !slices.Contains(allowedIPs, clientIP) { + k.setInvalid(StatusForbidden, fmt.Sprintf("client IP %s is not in the whitelist", clientIP)) + } + + return nil +} + +func (k *KeyVerifier) WithApiID(apiID string) { + if k.Status != StatusValid { + return + } + + if k.Key.ApiID != apiID { + k.setInvalid(StatusForbidden, fmt.Sprintf("The key does not belong to %s", apiID)) + } +} + +// withPermissions validates that the key has the required RBAC permissions. +// It uses the configured RBAC system to evaluate the permission query against the key's permissions. +func (k *KeyVerifier) withPermissions(ctx context.Context, query rbac.PermissionQuery) error { + ctx, span := tracing.Start(ctx, "verify.withPermissions") + defer span.End() + + if k.Status != StatusValid { + return nil + } + + allowed, err := k.rBAC.EvaluatePermissions(query, k.Permissions) + if err != nil { + return err + } + + if !allowed.Valid { + k.setInvalid(StatusInsufficientPermissions, allowed.Message) + } + + return nil +} + +// withRateLimits validates the key against both auto-applied and specified rate limits. +// Auto-applied limits come from the key or identity configuration, while specified limits +// are provided at verification time. All limits must pass for the key to be valid. +func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []openapi.KeysVerifyKeyRatelimit) error { + // nolint:ineffassign + ctx, span := tracing.Start(ctx, "verify.withRateLimits") + defer span.End() + + if k.Status != StatusValid { + return nil + } + + ratelimitsToCheck := make(map[string]RatelimitConfigAndResult) + for name, rl := range k.ratelimitConfigs { + if rl.AutoApply == 0 { + continue + } + + identifier := k.Key.ID + if rl.IdentityID != "" { + identifier = rl.IdentityID + } + + ratelimitsToCheck[name] = RatelimitConfigAndResult{ + Cost: 1, + Name: rl.Name, + Duration: time.Duration(rl.Duration) * time.Millisecond, + Limit: int64(rl.Limit), + AutoApply: rl.AutoApply == 1, + Identifier: identifier, + Response: nil, + } + } + + for _, rl := range specifiedLimits { + if rl.Limit != nil && rl.Duration != nil { + ratelimitsToCheck[rl.Name] = RatelimitConfigAndResult{ + Cost: int64(ptr.SafeDeref(rl.Cost, 1)), + Name: rl.Name, + Duration: time.Duration(*rl.Duration) * time.Millisecond, + Limit: int64(*rl.Limit), + AutoApply: false, + Identifier: k.Key.ID, // Specified limits use key ID + Response: nil, + } + + continue + } + + dbRl, exists := k.ratelimitConfigs[rl.Name] + if !exists { + errorMsg := "ratelimit %q was requested but does not exist for key %q" + if k.Key.IdentityID.Valid { + errorMsg += " nor identity: %q external ID: %q" + } else { + errorMsg += " and there is no identity connected." + } + + errorMsg = fmt.Sprintf(errorMsg, rl.Name, k.Key.ID, k.Key.IdentityID.String, k.Key.ExternalID.String) + return fault.New("invalid ratelimit requested", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Public(errorMsg), + ) + } + + identifier := k.Key.ID + if dbRl.IdentityID != "" { + identifier = dbRl.IdentityID + } + + ratelimitsToCheck[rl.Name] = RatelimitConfigAndResult{ + Name: dbRl.Name, + Duration: time.Duration(dbRl.Duration) * time.Millisecond, + Cost: int64(ptr.SafeDeref(rl.Cost, 1)), + Limit: int64(dbRl.Limit), + AutoApply: dbRl.AutoApply == 1, + Identifier: identifier, + Response: nil, + } + } + + if len(ratelimitsToCheck) == 0 { + return nil + } + + for name, config := range ratelimitsToCheck { + response, err := k.rateLimiter.Ratelimit(ctx, ratelimit.RatelimitRequest{ + Identifier: config.Identifier, // Use the pre-determined identifier + Limit: config.Limit, + Duration: config.Duration, + Cost: config.Cost, + Time: time.Now(), + }) + if err != nil { + return err + } + + config.Response = &response + ratelimitsToCheck[name] = config + + // If rate limit exceeded, stop processing + if !response.Success { + k.setInvalid(StatusRateLimited, fmt.Sprintf("key exceeded rate limit %s", name)) + break + } + } + + // Store the final results + k.RatelimitResults = ratelimitsToCheck + + return nil +} diff --git a/go/internal/services/keys/verifier.go b/go/internal/services/keys/verifier.go new file mode 100644 index 0000000000..528d663277 --- /dev/null +++ b/go/internal/services/keys/verifier.go @@ -0,0 +1,134 @@ +package keys + +import ( + "context" + "strconv" + "time" + + "github.com/unkeyed/unkey/go/internal/services/ratelimit" + "github.com/unkeyed/unkey/go/internal/services/usagelimiter" + "github.com/unkeyed/unkey/go/pkg/clickhouse" + "github.com/unkeyed/unkey/go/pkg/clickhouse/schema" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/prometheus/metrics" + "github.com/unkeyed/unkey/go/pkg/rbac" + "github.com/unkeyed/unkey/go/pkg/zen" +) + +// RatelimitConfigAndResult holds both the configuration and result for a rate limit +type RatelimitConfigAndResult struct { + Cost int64 + Name string + Duration time.Duration + Limit int64 + AutoApply bool + Identifier string // The identifier to use for this rate limit + Response *ratelimit.RatelimitResponse // nil until rate limit is checked +} + +// KeyVerifier represents a key that has been loaded from the database and is ready for verification. +// It contains all the necessary information and services to perform various validation checks. +type KeyVerifier struct { + Key db.FindKeyForVerificationRow // The key data from the database + ratelimitConfigs map[string]db.KeyFindForVerificationRatelimit // Rate limits configured for this key (name -> config) + Roles []string // RBAC roles assigned to this key + Permissions []string // RBAC permissions assigned to this key + Status KeyStatus // The current validation status + AuthorizedWorkspaceID string // The workspace ID this key is authorized for + RatelimitResults map[string]RatelimitConfigAndResult // Combined config and results for rate limits (name -> config+result) + isRootKey bool // Whether this is a root key (special handling) + session *zen.Session // The current request session + rateLimiter ratelimit.Service // Rate limiting service + usageLimiter usagelimiter.Service // Usage limiting service + rBAC *rbac.RBAC // Role-based access control service + clickhouse clickhouse.ClickHouse // Clickhouse for telemetry + logger logging.Logger // Logger for verification operations + message string // Internal message for validation failures +} + +// GetRatelimitConfigs returns the rate limit configurations +func (k *KeyVerifier) GetRatelimitConfigs() map[string]db.KeyFindForVerificationRatelimit { + return k.ratelimitConfigs +} + +// Verify performs key verification with the given options. +// For root keys: returns fault errors for validation failures. +// For normal keys: returns error only for system problems, check k.Valid and k.Status for validation results. +func (k *KeyVerifier) Verify(ctx context.Context, opts ...VerifyOption) error { + // Skip verification if key is already invalid + if k.Status != StatusValid { + // For root keys, auto-return validation failures as fault errors + if k.isRootKey { + return k.ToFault() + } + return nil + } + + // nolint:exhaustruct + config := &verifyConfig{} + for _, opt := range opts { + if err := opt(config); err != nil { + return err + } + } + + var err error + if config.credits != nil { + err = k.withCredits(ctx, *config.credits) + if err != nil { + return err + } + } + + if config.ipWhitelist { + err = k.withIPWhitelist() + if err != nil { + return err + } + } + + if config.permissions != nil { + err = k.withPermissions(ctx, *config.permissions) + if err != nil { + return err + } + } + + if config.apiID != nil { + k.WithApiID(*config.apiID) + } + + err = k.withRateLimits(ctx, config.ratelimits) + if err != nil { + return err + } + + k.clickhouse.BufferKeyVerification(schema.KeyVerificationRequestV1{ + RequestID: k.session.RequestID(), + WorkspaceID: k.session.AuthorizedWorkspaceID(), + Time: time.Now().UnixMilli(), + Region: "", + Outcome: string(k.Status), + KeySpaceID: k.Key.KeyAuthID, + KeyID: k.Key.ID, + IdentityID: k.Key.IdentityID.String, + Tags: config.tags, + }) + + // Emit Prometheus metrics for key verification + metrics.KeyVerificationsTotal.WithLabelValues( + k.AuthorizedWorkspaceID, // workspaceId + k.Key.ApiID, // apiId + k.Key.ID, // keyId + strconv.FormatBool(k.Status == StatusValid), // valid + string(k.Status), // code + ).Inc() + + // For root keys, auto-return validation failures as fault errors + if k.isRootKey && k.Status != StatusValid { + return k.ToFault() + } + + return nil +} diff --git a/go/internal/services/keys/verify.go b/go/internal/services/keys/verify.go deleted file mode 100644 index 831fb33118..0000000000 --- a/go/internal/services/keys/verify.go +++ /dev/null @@ -1,123 +0,0 @@ -package keys - -import ( - "context" - - "github.com/unkeyed/unkey/go/internal/services/caches" - "github.com/unkeyed/unkey/go/pkg/assert" - "github.com/unkeyed/unkey/go/pkg/codes" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/hash" - "github.com/unkeyed/unkey/go/pkg/otel/tracing" -) - -func (s *service) Verify(ctx context.Context, rawKey string) (VerifyResponse, error) { - ctx, span := tracing.Start(ctx, "keys.VerifyRootKey") - defer span.End() - - err := assert.NotEmpty(rawKey) - if err != nil { - return VerifyResponse{}, fault.Wrap(err, fault.Internal("rawKey is empty")) - } - h := hash.Sha256(rawKey) - - key, err := s.keyCache.SWR(ctx, h, func(ctx context.Context) (db.Key, error) { - return db.Query.FindKeyByHash(ctx, s.db.RO(), h) - }, caches.DefaultFindFirstOp) - - if db.IsNotFound(err) { - return VerifyResponse{}, fault.Wrap( - err, - fault.Code(codes.Auth.Authentication.KeyNotFound.URN()), - fault.Internal("key does not exist"), - fault.Public("We could not find the requested key."), - ) - } - - if err != nil { - return VerifyResponse{}, fault.Wrap( - err, - fault.Internal("unable to load key"), - fault.Public("We could not load the requested key."), - ) - } - - // Following are various checks to ensure the validity of the key - // - Is it enabled? - // - Is it deleted? - // - Is it expired? - // - Is it ratelimited? - // - Is the related workspace deleted? - // - Is the related workspace disabled? - // - Is the related forWorkspace deleted? - // - Is the related forWorkspace disabled? - - if key.DeletedAtM.Valid { - return VerifyResponse{}, fault.New( - "key is deleted", - fault.Code(codes.Data.Key.NotFound.URN()), - fault.Internal("deleted_at is non-zero"), - fault.Public("The key has been deleted."), - ) - } - - if !key.Enabled { - return VerifyResponse{}, fault.New( - "key is disabled", - fault.Code(codes.Auth.Authorization.KeyDisabled.URN()), - fault.Internal("disabled"), - fault.Public("The key is disabled."), - ) - } - - authorizedWorkspaceID := key.WorkspaceID - if key.ForWorkspaceID.Valid { - authorizedWorkspaceID = key.ForWorkspaceID.String - } - - ws, err := s.workspaceCache.SWR(ctx, authorizedWorkspaceID, func(ctx context.Context) (db.Workspace, error) { - return db.Query.FindWorkspaceByID(ctx, s.db.RW(), authorizedWorkspaceID) - }, caches.DefaultFindFirstOp) - - if db.IsNotFound(err) { - return VerifyResponse{}, fault.New( - "workspace not found", - fault.Code(codes.Data.Workspace.NotFound.URN()), - fault.Internal("workspace not found"), - fault.Public("The requested workspace does not exist."), - ) - } - if err != nil { - s.logger.Error("unable to load workspace", - "error", err.Error()) - return VerifyResponse{}, fault.Wrap( - err, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("unable to load workspace"), - fault.Public("We could not load the requested workspace."), - ) - } - - if !ws.Enabled { - return VerifyResponse{}, fault.New( - "workspace is disabled", - fault.Code(codes.Auth.Authorization.WorkspaceDisabled.URN()), - fault.Internal("workspace disabled"), - fault.Public("The workspace is disabled."), - ) - } - - res := VerifyResponse{ - AuthorizedWorkspaceID: authorizedWorkspaceID, - KeyID: key.ID, - } - - // Root keys store the user's workspace id in `ForWorkspaceID` and we're - // interested in the user, not our rootkey workspace. - if key.ForWorkspaceID.Valid { - res.AuthorizedWorkspaceID = key.ForWorkspaceID.String - } - - return res, nil -} diff --git a/go/internal/services/keys/verify_root_key.go b/go/internal/services/keys/verify_root_key.go deleted file mode 100644 index 729abf1261..0000000000 --- a/go/internal/services/keys/verify_root_key.go +++ /dev/null @@ -1,30 +0,0 @@ -package keys - -import ( - "context" - - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/zen" -) - -func (s *service) VerifyRootKey(ctx context.Context, sess *zen.Session) (VerifyResponse, error) { - - rootKey, err := zen.Bearer(sess) - if err != nil { - return VerifyResponse{}, fault.Wrap(err, - fault.Internal("no bearer"), - fault.Public("You must provide a valid root key in the Authorization header in the format 'Bearer ROOT_KEY'."), - ) - } - - res, err := s.Verify(ctx, rootKey) - if err != nil { - return VerifyResponse{}, fault.Wrap(err, - fault.Internal("invalid root key"), - fault.Public("The provided root key is invalid.")) - } - sess.WorkspaceID = res.AuthorizedWorkspaceID - - return res, nil - -} diff --git a/go/internal/services/permissions/check.go b/go/internal/services/permissions/check.go deleted file mode 100644 index fee99adff3..0000000000 --- a/go/internal/services/permissions/check.go +++ /dev/null @@ -1,53 +0,0 @@ -package permissions - -import ( - "context" - - "github.com/unkeyed/unkey/go/pkg/cache" - "github.com/unkeyed/unkey/go/pkg/codes" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/otel/tracing" - "github.com/unkeyed/unkey/go/pkg/rbac" -) - -func (s *service) Check(ctx context.Context, keyID string, query rbac.PermissionQuery) error { - - ctx, span := tracing.Start(ctx, "permissions.Check") - defer span.End() - - permissions, err := s.cache.SWR(ctx, keyID, func(ctx context.Context) ([]string, error) { - return db.Query.ListPermissionsByKeyID(ctx, s.db.RO(), db.ListPermissionsByKeyIDParams{ - KeyID: keyID, - }) - - }, func(err error) cache.Op { - if err == nil { - return cache.WriteValue - } - return cache.Noop - - }) - - if err != nil { - return fault.Wrap(err, fault.Internal("unable to load permissions from db")) - } - - res, err := s.rbac.EvaluatePermissions(query, permissions) - if err != nil { - return fault.New("unable to evaluate permissions", - fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal(err.Error()), - fault.Public("Unhandled exception during permission evaluation."), - ) - } - if !res.Valid { - return fault.New("insufficient permissions", - fault.Code(codes.Auth.Authorization.InsufficientPermissions.URN()), - fault.Internal(res.Message), - fault.Public(res.Message), - ) - } - - return nil -} diff --git a/go/internal/services/permissions/interface.go b/go/internal/services/permissions/interface.go deleted file mode 100644 index e64f200502..0000000000 --- a/go/internal/services/permissions/interface.go +++ /dev/null @@ -1,14 +0,0 @@ -package permissions - -import ( - "context" - - "github.com/unkeyed/unkey/go/pkg/rbac" -) - -type PermissionService interface { - // If the user does not have the required permissions, an error is returned. - // The returned error will have a code of codes.Auth.Authorization.InsufficientPermissions.URN() - // and can be returned in a zen route as is. - Check(ctx context.Context, keyId string, query rbac.PermissionQuery) error -} diff --git a/go/internal/services/permissions/service.go b/go/internal/services/permissions/service.go deleted file mode 100644 index 6401132113..0000000000 --- a/go/internal/services/permissions/service.go +++ /dev/null @@ -1,37 +0,0 @@ -package permissions - -import ( - "github.com/unkeyed/unkey/go/pkg/cache" - "github.com/unkeyed/unkey/go/pkg/clock" - "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/rbac" -) - -type service struct { - db db.Database - logger logging.Logger - rbac *rbac.RBAC - - // keyId -> permissions - cache cache.Cache[string, []string] -} - -var _ PermissionService = (*service)(nil) - -type Config struct { - DB db.Database - Logger logging.Logger - Clock clock.Clock - Cache cache.Cache[string, []string] -} - -func New(config Config) (*service, error) { - - return &service{ - db: config.DB, - logger: config.Logger, - rbac: rbac.New(), - cache: config.Cache, - }, nil -} diff --git a/go/internal/services/ratelimit/replay.go b/go/internal/services/ratelimit/replay.go index 5f4a2059fc..a9b613fca0 100644 --- a/go/internal/services/ratelimit/replay.go +++ b/go/internal/services/ratelimit/replay.go @@ -58,6 +58,7 @@ func (s *service) syncWithOrigin(ctx context.Context, req RatelimitRequest) erro if err != nil { return err } + key := bucketKey{ identifier: req.Identifier, limit: req.Limit, @@ -79,8 +80,8 @@ func (s *service) syncWithOrigin(ctx context.Context, req RatelimitRequest) erro req.Cost, currentWindow.duration*3, ) - }) + if err != nil { tracing.RecordError(span, err) diff --git a/go/internal/services/ratelimit/service.go b/go/internal/services/ratelimit/service.go index 2278e3c82e..265819e1e2 100644 --- a/go/internal/services/ratelimit/service.go +++ b/go/internal/services/ratelimit/service.go @@ -90,6 +90,7 @@ type Config struct { Logger logging.Logger Clock clock.Clock + // If provided, use this counter implementation instead of creating a Redis counter Counter counter.Counter } @@ -225,7 +226,6 @@ func (s *service) calculateRateLimit(req RatelimitRequest, currentWindow, previo // time.UnixMilli(resp.Reset)) // } func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (RatelimitResponse, error) { - _, span := tracing.Start(ctx, "Ratelimit") defer span.End() @@ -264,7 +264,6 @@ func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (Ratelimi // Check if we can reject based on local data alone exceeded, effectiveCount, remaining := s.calculateRateLimit(req, currentWindow, previousWindow) if exceeded { - b.strictUntil = req.Time.Add(req.Duration) // Record the denied request diff --git a/go/internal/services/usagelimiter/interface.go b/go/internal/services/usagelimiter/interface.go new file mode 100644 index 0000000000..e45e0a59bd --- /dev/null +++ b/go/internal/services/usagelimiter/interface.go @@ -0,0 +1,20 @@ +package usagelimiter + +import ( + "context" +) + +type Service interface { + // If the given keyId has exceeded its usage limit, an error is returned. + Limit(ctx context.Context, req UsageRequest) (UsageResponse, error) +} + +type UsageRequest struct { + KeyId string + Cost int32 +} + +type UsageResponse struct { + Valid bool + Remaining int32 // Remaining usage for the keyId -1 indicates no limit +} diff --git a/go/internal/services/usagelimiter/limit.go b/go/internal/services/usagelimiter/limit.go new file mode 100644 index 0000000000..a27e484a5e --- /dev/null +++ b/go/internal/services/usagelimiter/limit.go @@ -0,0 +1,45 @@ +package usagelimiter + +import ( + "context" + "database/sql" + "math" + + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/tracing" +) + +func (s *service) Limit(ctx context.Context, req UsageRequest) (UsageResponse, error) { + ctx, span := tracing.Start(ctx, "usagelimiter.Limit") + defer span.End() + + limit, err := db.Query.FindKeyCredits(ctx, s.db.RW(), req.KeyId) + if err != nil { + if db.IsNotFound(err) { + return UsageResponse{Valid: false, Remaining: 0}, nil + } + + return UsageResponse{Valid: false, Remaining: 0}, err + } + + if !limit.Valid { + return UsageResponse{Valid: true, Remaining: -1}, nil + } + remaining := limit.Int32 + + // Key doesn't have enough credits to cover the request cost + if remaining <= 0 && req.Cost != 0 || remaining-req.Cost < 0 { + return UsageResponse{Valid: false, Remaining: 0}, nil + } + + err = db.Query.UpdateKeyCredits(ctx, s.db.RW(), db.UpdateKeyCreditsParams{ + ID: req.KeyId, + Operation: "decrement", + Credits: sql.NullInt32{Int32: req.Cost, Valid: true}, + }) + if err != nil { + return UsageResponse{}, err + } + + return UsageResponse{Valid: true, Remaining: int32(math.Max(float64(0), float64(remaining-req.Cost)))}, nil +} diff --git a/go/internal/services/usagelimiter/service.go b/go/internal/services/usagelimiter/service.go new file mode 100644 index 0000000000..5abf94c142 --- /dev/null +++ b/go/internal/services/usagelimiter/service.go @@ -0,0 +1,25 @@ +package usagelimiter + +import ( + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +type service struct { + db db.Database + logger logging.Logger +} + +var _ Service = (*service)(nil) + +type Config struct { + DB db.Database + Logger logging.Logger +} + +func New(config Config) (*service, error) { + return &service{ + db: config.DB, + logger: config.Logger, + }, nil +} diff --git a/go/pkg/cache/cache.go b/go/pkg/cache/cache.go index 4e1cdddd7e..013fa30ea2 100644 --- a/go/pkg/cache/cache.go +++ b/go/pkg/cache/cache.go @@ -99,38 +99,27 @@ func New[K comparable, V any](config Config[K, V]) (*cache[K, V], error) { } func (c *cache[K, V]) registerMetrics() { - repeat.Every(60*time.Second, func() { - metrics.CacheSize.WithLabelValues(c.resource).Set(float64(c.otter.Size())) metrics.CacheCapacity.WithLabelValues(c.resource).Set(float64(c.otter.Capacity())) - }) - } func (c *cache[K, V]) Get(ctx context.Context, key K) (value V, hit CacheHit) { - e, ok := c.get(ctx, key) if !ok { - // This hack is necessary because you can not return nil as V - var v V - - return v, Miss + return value, Miss } now := c.clock.Now() if now.Before(e.Stale) { - return e.Value, e.Hit } c.otter.Delete(key) - var v V - return v, Miss - + return value, Miss } func (c *cache[K, V]) SetNull(_ context.Context, key K) { @@ -179,21 +168,20 @@ func (c *cache[K, V]) Dump(ctx context.Context) ([]byte, error) { }) b, err := json.Marshal(data) - if err != nil { return nil, fault.Wrap(err, fault.Internal("failed to marshal cache data")) } - return b, nil + return b, nil } func (c *cache[K, V]) Restore(ctx context.Context, b []byte) error { - data := make(map[K]swrEntry[V]) err := json.Unmarshal(b, &data) if err != nil { return fmt.Errorf("failed to unmarshal cache data: %w", err) } + now := c.clock.Now() for key, entry := range data { if now.Before(entry.Fresh) || now.Before(entry.Stale) { @@ -201,6 +189,7 @@ func (c *cache[K, V]) Restore(ctx context.Context, b []byte) error { } // If the entry is older than, we don't restore it } + return nil } @@ -213,13 +202,13 @@ func (c *cache[K, V]) revalidate( key K, refreshFromOrigin func(context.Context) (V, error), op func(error) Op, ) { - c.inflightMu.Lock() _, ok := c.inflightRefreshes[key] if ok { c.inflightMu.Unlock() return } + c.inflightRefreshes[key] = true c.inflightMu.Unlock() @@ -234,6 +223,7 @@ func (c *cache[K, V]) revalidate( if err != nil && !db.IsNotFound(err) { c.logger.Warn("failed to revalidate", "error", err.Error(), "key", key) } + switch op(err) { case WriteValue: c.Set(ctx, key, v) @@ -242,7 +232,6 @@ func (c *cache[K, V]) revalidate( case Noop: break } - } func (c *cache[K, V]) SWR( @@ -251,7 +240,6 @@ func (c *cache[K, V]) SWR( refreshFromOrigin func(context.Context) (V, error), op func(error) Op, ) (V, error) { - now := c.clock.Now() e, ok := c.get(ctx, key) if ok { @@ -259,7 +247,6 @@ func (c *cache[K, V]) SWR( if now.Before(e.Fresh) { // We have data and it's fresh, so we return it - return e.Value, nil } @@ -278,12 +265,11 @@ func (c *cache[K, V]) SWR( // We have old data, that we should not serve anymore c.otter.Delete(key) - } + // Cache Miss // We have no data and need to go to the origin - v, err := refreshFromOrigin(ctx) switch op(err) { @@ -296,5 +282,4 @@ func (c *cache[K, V]) SWR( } return v, err - } diff --git a/go/pkg/cache/entry.go b/go/pkg/cache/entry.go index 2577bebe9e..c032acf583 100644 --- a/go/pkg/cache/entry.go +++ b/go/pkg/cache/entry.go @@ -8,8 +8,10 @@ type swrEntry[T any] struct { Value T `json:"value"` Hit CacheHit `json:"hit"` + // Before this time the entry is considered fresh and vaid Fresh time.Time `json:"fresh"` + // Before this time, the entry should be revalidated // After this time, the entry must be discarded Stale time.Time `json:"stale"` diff --git a/go/pkg/circuitbreaker/lib.go b/go/pkg/circuitbreaker/lib.go index f9ef28f80a..f51761045d 100644 --- a/go/pkg/circuitbreaker/lib.go +++ b/go/pkg/circuitbreaker/lib.go @@ -109,7 +109,6 @@ func WithLogger(logger logging.Logger) applyConfig { type applyConfig func(*config) func New[Res any](name string, applyConfigs ...applyConfig) *CB[Res] { - cfg := &config{ name: name, maxRequests: 10, diff --git a/go/pkg/clickhouse/schema/requests.go b/go/pkg/clickhouse/schema/requests.go index b24cf3b98a..5186004672 100644 --- a/go/pkg/clickhouse/schema/requests.go +++ b/go/pkg/clickhouse/schema/requests.go @@ -64,7 +64,6 @@ type ApiRequestV1 struct { // // Fields are mapped to ClickHouse columns using the `ch` struct tags. type KeyVerificationRequestV1 struct { - // RequestID is a unique identifier for this verification request RequestID string `ch:"request_id"` diff --git a/go/pkg/codes/codes.go b/go/pkg/codes/codes.go index 87643cb7b5..6d7ef8b24c 100644 --- a/go/pkg/codes/codes.go +++ b/go/pkg/codes/codes.go @@ -60,6 +60,9 @@ const ( CategoryUnkeyLimits Category = "limits" CategoryUnkeyApplication Category = "application" + + // CategoryUnkeyVault represents vault-related errors. + CategoryUnkeyVault Category = "vault" ) // Code represents a specific error with its metadata. It contains all components diff --git a/go/pkg/codes/generate.go b/go/pkg/codes/generate.go index cf26698335..4406d7522a 100644 --- a/go/pkg/codes/generate.go +++ b/go/pkg/codes/generate.go @@ -12,7 +12,6 @@ import ( "path/filepath" "reflect" "strings" - "unicode" "github.com/unkeyed/unkey/go/pkg/codes" ) @@ -36,7 +35,7 @@ func main() { defer f.Close() // Write file header - f.WriteString(fmt.Sprintf("// Code generated by generate.go; DO NOT EDIT.\n")) + f.WriteString("// Code generated by generate.go; DO NOT EDIT.\n") f.WriteString("package codes\n\n") // Generate constants @@ -123,9 +122,9 @@ func processErrorDomain(f *os.File, systemName string, domainValue reflect.Value // Section header domainType := domainValue.Type() domainName := domainType.Name() - f.WriteString(fmt.Sprintf("// ----------------\n")) - f.WriteString(fmt.Sprintf("// %s\n", domainName)) - f.WriteString(fmt.Sprintf("// ----------------\n")) + f.WriteString("// ----------------\n") + fmt.Fprintf(f, "// %s\n", domainName) + f.WriteString("// ----------------\n") f.WriteString("\n") // Iterate through categories (fields of the domain struct) @@ -133,7 +132,7 @@ func processErrorDomain(f *os.File, systemName string, domainValue reflect.Value categoryField := domainValue.Field(i) categoryName := domainType.Field(i).Name - f.WriteString(fmt.Sprintf("// %s\n\n", categoryName)) + fmt.Fprintf(f, "// %s\n\n", categoryName) // Iterate through error codes in this category processCategory(f, systemName, domainName, categoryName, categoryField) @@ -178,15 +177,3 @@ func processCategory(f *os.File, systemName, domainName, categoryName string, ca f.WriteString(fmt.Sprintf("\t%s URN = \"%s\"\n", constName, codeStr)) } } - -// toSnakeCase converts a string from PascalCase to snake_case -func toSnakeCase(s string) string { - var result strings.Builder - for i, r := range s { - if i > 0 && unicode.IsUpper(r) { - result.WriteRune('_') - } - result.WriteRune(unicode.ToLower(r)) - } - return result.String() -} diff --git a/go/pkg/counter/redis.go b/go/pkg/counter/redis.go index 55796c3cdd..aac607a7ed 100644 --- a/go/pkg/counter/redis.go +++ b/go/pkg/counter/redis.go @@ -124,9 +124,11 @@ func (r *redisCounter) Get(ctx context.Context, key string) (int64, error) { // Key doesn't exist, return 0 without error return 0, nil } + if err != nil { return 0, err } + return strconv.ParseInt(res, 10, 64) } diff --git a/go/pkg/db/api_insert.sql_generated.go b/go/pkg/db/api_insert.sql_generated.go index a71d448966..ee19a5c7bc 100644 --- a/go/pkg/db/api_insert.sql_generated.go +++ b/go/pkg/db/api_insert.sql_generated.go @@ -16,6 +16,7 @@ INSERT INTO apis ( name, workspace_id, auth_type, + ip_whitelist, key_auth_id, created_at_m, deleted_at_m @@ -26,6 +27,7 @@ INSERT INTO apis ( ?, ?, ?, + ?, NULL ) ` @@ -35,6 +37,7 @@ type InsertApiParams struct { Name string `db:"name"` WorkspaceID string `db:"workspace_id"` AuthType NullApisAuthType `db:"auth_type"` + IpWhitelist sql.NullString `db:"ip_whitelist"` KeyAuthID sql.NullString `db:"key_auth_id"` CreatedAtM int64 `db:"created_at_m"` } @@ -46,6 +49,7 @@ type InsertApiParams struct { // name, // workspace_id, // auth_type, +// ip_whitelist, // key_auth_id, // created_at_m, // deleted_at_m @@ -56,6 +60,7 @@ type InsertApiParams struct { // ?, // ?, // ?, +// ?, // NULL // ) func (q *Queries) InsertApi(ctx context.Context, db DBTX, arg InsertApiParams) error { @@ -64,6 +69,7 @@ func (q *Queries) InsertApi(ctx context.Context, db DBTX, arg InsertApiParams) e arg.Name, arg.WorkspaceID, arg.AuthType, + arg.IpWhitelist, arg.KeyAuthID, arg.CreatedAtM, ) diff --git a/go/pkg/db/key_find_credits.sql_generated.go b/go/pkg/db/key_find_credits.sql_generated.go new file mode 100644 index 0000000000..0f7a87e3b1 --- /dev/null +++ b/go/pkg/db/key_find_credits.sql_generated.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: key_find_credits.sql + +package db + +import ( + "context" + "database/sql" +) + +const findKeyCredits = `-- name: FindKeyCredits :one +SELECT remaining_requests FROM ` + "`" + `keys` + "`" + ` k WHERE k.id = ? +` + +// FindKeyCredits +// +// SELECT remaining_requests FROM `keys` k WHERE k.id = ? +func (q *Queries) FindKeyCredits(ctx context.Context, db DBTX, id string) (sql.NullInt32, error) { + row := db.QueryRowContext(ctx, findKeyCredits, id) + var remaining_requests sql.NullInt32 + err := row.Scan(&remaining_requests) + return remaining_requests, err +} diff --git a/go/pkg/db/key_find_for_verification.sql_generated.go b/go/pkg/db/key_find_for_verification.sql_generated.go index f2a0d90fd1..5e463a29c2 100644 --- a/go/pkg/db/key_find_for_verification.sql_generated.go +++ b/go/pkg/db/key_find_for_verification.sql_generated.go @@ -8,170 +8,224 @@ package db import ( "context" "database/sql" - "encoding/json" ) const findKeyForVerification = `-- name: FindKeyForVerification :one -WITH direct_permissions AS ( - SELECT kp.key_id, p.name as permission_name - FROM keys_permissions kp - JOIN permissions p ON kp.permission_id = p.id -), -role_permissions AS ( - SELECT kr.key_id, p.name as permission_name - FROM keys_roles kr - JOIN roles_permissions rp ON kr.role_id = rp.role_id - JOIN permissions p ON rp.permission_id = p.id -), -all_permissions AS ( - SELECT key_id, permission_name FROM direct_permissions - UNION - SELECT key_id, permission_name FROM role_permissions -), -all_ratelimits AS ( - SELECT - key_id as target_id, - 'key' as target_type, - name, - ` + "`" + `limit` + "`" + `, - duration - FROM ratelimits - WHERE key_id IS NOT NULL - UNION - SELECT - identity_id as target_id, - 'identity' as target_type, - name, - ` + "`" + `limit` + "`" + `, - duration - FROM ratelimits - WHERE identity_id IS NOT NULL -) -SELECT - k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, - i.id, i.external_id, i.workspace_id, i.environment, i.meta, i.deleted, i.created_at, i.updated_at, - JSON_ARRAYAGG( - JSON_OBJECT( - 'target_type', rl.target_type, - 'name', rl.name, - 'limit', rl.limit, - 'duration', rl.duration - ) - ) as ratelimits, - GROUP_CONCAT(DISTINCT perms.permission_name) as permissions -FROM ` + "`" + `keys` + "`" + ` k -LEFT JOIN identities i ON k.identity_id = i.id -LEFT JOIN all_permissions perms ON k.id = perms.key_id -LEFT JOIN all_ratelimits rl ON ( - (rl.target_type = 'key' AND rl.target_id = k.id) OR - (rl.target_type = 'identity' AND rl.target_id = k.identity_id) -) -WHERE k.hash = ? -GROUP BY k.id +select k.id, + k.key_auth_id, + k.workspace_id, + k.for_workspace_id, + k.name, + k.meta, + k.expires, + k.deleted_at_m, + k.refill_day, + k.refill_amount, + k.last_refill_at, + k.enabled, + k.remaining_requests, + a.ip_whitelist, + a.workspace_id as api_workspace_id, + a.id as api_id, + a.deleted_at_m as api_deleted_at_m, + + COALESCE( + (SELECT JSON_ARRAYAGG(name) + FROM (SELECT name + FROM keys_roles kr + JOIN roles r ON r.id = kr.role_id + WHERE kr.key_id = k.id) as roles), + JSON_ARRAY() + ) as roles, + + COALESCE( + (SELECT JSON_ARRAYAGG(slug) + FROM (SELECT slug + FROM keys_permissions kp + JOIN permissions p ON kp.permission_id = p.id + WHERE kp.key_id = k.id + + UNION ALL + + SELECT slug + FROM keys_roles kr + JOIN roles_permissions rp ON kr.role_id = rp.role_id + JOIN permissions p ON rp.permission_id = p.id + WHERE kr.key_id = k.id) as combined_perms), + JSON_ARRAY() + ) as permissions, + + coalesce( + (select json_arrayagg( + json_object( + 'id', rl.id, + 'name', rl.name, + 'key_id', rl.key_id, + 'identity_id', rl.identity_id, + 'limit', rl.limit, + 'duration', rl.duration, + 'auto_apply', rl.auto_apply + ) + ) + from ` + "`" + `ratelimits` + "`" + ` rl + where rl.key_id = k.id + OR rl.identity_id = i.id), + json_array() + ) as ratelimits, + + i.id as identity_id, + i.external_id, + i.meta as identity_meta, + ka.deleted_at_m as key_auth_deleted_at_m, + ws.enabled as workspace_enabled, + fws.enabled as for_workspace_enabled +from ` + "`" + `keys` + "`" + ` k + JOIN apis a USING (key_auth_id) + JOIN key_auth ka ON ka.id = k.key_auth_id + JOIN workspaces ws ON ws.id = k.workspace_id + LEFT JOIN workspaces fws ON fws.id = k.for_workspace_id + LEFT JOIN identities i ON k.identity_id = i.id AND i.deleted = 0 +where k.hash = ? + and k.deleted_at_m is null ` type FindKeyForVerificationRow struct { - Key Key `db:"key"` - Identity Identity `db:"identity"` - Ratelimits json.RawMessage `db:"ratelimits"` - Permissions sql.NullString `db:"permissions"` + ID string `db:"id"` + KeyAuthID string `db:"key_auth_id"` + WorkspaceID string `db:"workspace_id"` + ForWorkspaceID sql.NullString `db:"for_workspace_id"` + Name sql.NullString `db:"name"` + Meta sql.NullString `db:"meta"` + Expires sql.NullTime `db:"expires"` + DeletedAtM sql.NullInt64 `db:"deleted_at_m"` + RefillDay sql.NullInt16 `db:"refill_day"` + RefillAmount sql.NullInt32 `db:"refill_amount"` + LastRefillAt sql.NullTime `db:"last_refill_at"` + Enabled bool `db:"enabled"` + RemainingRequests sql.NullInt32 `db:"remaining_requests"` + IpWhitelist sql.NullString `db:"ip_whitelist"` + ApiWorkspaceID string `db:"api_workspace_id"` + ApiID string `db:"api_id"` + ApiDeletedAtM sql.NullInt64 `db:"api_deleted_at_m"` + Roles interface{} `db:"roles"` + Permissions interface{} `db:"permissions"` + Ratelimits interface{} `db:"ratelimits"` + IdentityID sql.NullString `db:"identity_id"` + ExternalID sql.NullString `db:"external_id"` + IdentityMeta []byte `db:"identity_meta"` + KeyAuthDeletedAtM sql.NullInt64 `db:"key_auth_deleted_at_m"` + WorkspaceEnabled bool `db:"workspace_enabled"` + ForWorkspaceEnabled sql.NullBool `db:"for_workspace_enabled"` } // FindKeyForVerification // -// WITH direct_permissions AS ( -// SELECT kp.key_id, p.name as permission_name -// FROM keys_permissions kp -// JOIN permissions p ON kp.permission_id = p.id -// ), -// role_permissions AS ( -// SELECT kr.key_id, p.name as permission_name -// FROM keys_roles kr -// JOIN roles_permissions rp ON kr.role_id = rp.role_id -// JOIN permissions p ON rp.permission_id = p.id -// ), -// all_permissions AS ( -// SELECT key_id, permission_name FROM direct_permissions -// UNION -// SELECT key_id, permission_name FROM role_permissions -// ), -// all_ratelimits AS ( -// SELECT -// key_id as target_id, -// 'key' as target_type, -// name, -// `limit`, -// duration -// FROM ratelimits -// WHERE key_id IS NOT NULL -// UNION -// SELECT -// identity_id as target_id, -// 'identity' as target_type, -// name, -// `limit`, -// duration -// FROM ratelimits -// WHERE identity_id IS NOT NULL -// ) -// SELECT -// k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, -// i.id, i.external_id, i.workspace_id, i.environment, i.meta, i.deleted, i.created_at, i.updated_at, -// JSON_ARRAYAGG( -// JSON_OBJECT( -// 'target_type', rl.target_type, -// 'name', rl.name, -// 'limit', rl.limit, -// 'duration', rl.duration -// ) -// ) as ratelimits, -// GROUP_CONCAT(DISTINCT perms.permission_name) as permissions -// FROM `keys` k -// LEFT JOIN identities i ON k.identity_id = i.id -// LEFT JOIN all_permissions perms ON k.id = perms.key_id -// LEFT JOIN all_ratelimits rl ON ( -// (rl.target_type = 'key' AND rl.target_id = k.id) OR -// (rl.target_type = 'identity' AND rl.target_id = k.identity_id) -// ) -// WHERE k.hash = ? -// GROUP BY k.id +// select k.id, +// k.key_auth_id, +// k.workspace_id, +// k.for_workspace_id, +// k.name, +// k.meta, +// k.expires, +// k.deleted_at_m, +// k.refill_day, +// k.refill_amount, +// k.last_refill_at, +// k.enabled, +// k.remaining_requests, +// a.ip_whitelist, +// a.workspace_id as api_workspace_id, +// a.id as api_id, +// a.deleted_at_m as api_deleted_at_m, +// +// COALESCE( +// (SELECT JSON_ARRAYAGG(name) +// FROM (SELECT name +// FROM keys_roles kr +// JOIN roles r ON r.id = kr.role_id +// WHERE kr.key_id = k.id) as roles), +// JSON_ARRAY() +// ) as roles, +// +// COALESCE( +// (SELECT JSON_ARRAYAGG(slug) +// FROM (SELECT slug +// FROM keys_permissions kp +// JOIN permissions p ON kp.permission_id = p.id +// WHERE kp.key_id = k.id +// +// UNION ALL +// +// SELECT slug +// FROM keys_roles kr +// JOIN roles_permissions rp ON kr.role_id = rp.role_id +// JOIN permissions p ON rp.permission_id = p.id +// WHERE kr.key_id = k.id) as combined_perms), +// JSON_ARRAY() +// ) as permissions, +// +// coalesce( +// (select json_arrayagg( +// json_object( +// 'id', rl.id, +// 'name', rl.name, +// 'key_id', rl.key_id, +// 'identity_id', rl.identity_id, +// 'limit', rl.limit, +// 'duration', rl.duration, +// 'auto_apply', rl.auto_apply +// ) +// ) +// from `ratelimits` rl +// where rl.key_id = k.id +// OR rl.identity_id = i.id), +// json_array() +// ) as ratelimits, +// +// i.id as identity_id, +// i.external_id, +// i.meta as identity_meta, +// ka.deleted_at_m as key_auth_deleted_at_m, +// ws.enabled as workspace_enabled, +// fws.enabled as for_workspace_enabled +// from `keys` k +// JOIN apis a USING (key_auth_id) +// JOIN key_auth ka ON ka.id = k.key_auth_id +// JOIN workspaces ws ON ws.id = k.workspace_id +// LEFT JOIN workspaces fws ON fws.id = k.for_workspace_id +// LEFT JOIN identities i ON k.identity_id = i.id AND i.deleted = 0 +// where k.hash = ? +// and k.deleted_at_m is null func (q *Queries) FindKeyForVerification(ctx context.Context, db DBTX, hash string) (FindKeyForVerificationRow, error) { row := db.QueryRowContext(ctx, findKeyForVerification, hash) var i FindKeyForVerificationRow err := row.Scan( - &i.Key.ID, - &i.Key.KeyAuthID, - &i.Key.Hash, - &i.Key.Start, - &i.Key.WorkspaceID, - &i.Key.ForWorkspaceID, - &i.Key.Name, - &i.Key.OwnerID, - &i.Key.IdentityID, - &i.Key.Meta, - &i.Key.Expires, - &i.Key.CreatedAtM, - &i.Key.UpdatedAtM, - &i.Key.DeletedAtM, - &i.Key.RefillDay, - &i.Key.RefillAmount, - &i.Key.LastRefillAt, - &i.Key.Enabled, - &i.Key.RemainingRequests, - &i.Key.RatelimitAsync, - &i.Key.RatelimitLimit, - &i.Key.RatelimitDuration, - &i.Key.Environment, - &i.Identity.ID, - &i.Identity.ExternalID, - &i.Identity.WorkspaceID, - &i.Identity.Environment, - &i.Identity.Meta, - &i.Identity.Deleted, - &i.Identity.CreatedAt, - &i.Identity.UpdatedAt, - &i.Ratelimits, + &i.ID, + &i.KeyAuthID, + &i.WorkspaceID, + &i.ForWorkspaceID, + &i.Name, + &i.Meta, + &i.Expires, + &i.DeletedAtM, + &i.RefillDay, + &i.RefillAmount, + &i.LastRefillAt, + &i.Enabled, + &i.RemainingRequests, + &i.IpWhitelist, + &i.ApiWorkspaceID, + &i.ApiID, + &i.ApiDeletedAtM, + &i.Roles, &i.Permissions, + &i.Ratelimits, + &i.IdentityID, + &i.ExternalID, + &i.IdentityMeta, + &i.KeyAuthDeletedAtM, + &i.WorkspaceEnabled, + &i.ForWorkspaceEnabled, ) return i, err } diff --git a/go/pkg/db/key_find_for_verification_ratelimits.go b/go/pkg/db/key_find_for_verification_ratelimits.go new file mode 100644 index 0000000000..3e21f0c266 --- /dev/null +++ b/go/pkg/db/key_find_for_verification_ratelimits.go @@ -0,0 +1,11 @@ +package db + +type KeyFindForVerificationRatelimit struct { + ID string `json:"id"` + Name string `json:"name"` + Limit int `json:"limit"` + Duration int `json:"duration"` + AutoApply int `json:"auto_apply"` + KeyID string `json:"key_id"` + IdentityID string `json:"identity_id"` +} diff --git a/go/pkg/db/key_insert.sql_generated.go b/go/pkg/db/key_insert.sql_generated.go index 5b9ccdd520..f33b0d8259 100644 --- a/go/pkg/db/key_insert.sql_generated.go +++ b/go/pkg/db/key_insert.sql_generated.go @@ -27,11 +27,7 @@ INSERT INTO ` + "`" + `keys` + "`" + ` ( enabled, remaining_requests, refill_day, - refill_amount, - ratelimit_async, - ratelimit_limit, - ratelimit_duration, - environment + refill_amount ) VALUES ( ?, ?, @@ -48,10 +44,6 @@ INSERT INTO ` + "`" + `keys` + "`" + ` ( ?, ?, ?, - ?, - ?, - ?, - ?, ? ) ` @@ -72,10 +64,6 @@ type InsertKeyParams struct { RemainingRequests sql.NullInt32 `db:"remaining_requests"` RefillDay sql.NullInt16 `db:"refill_day"` RefillAmount sql.NullInt32 `db:"refill_amount"` - RatelimitAsync sql.NullBool `db:"ratelimit_async"` - RatelimitLimit sql.NullInt32 `db:"ratelimit_limit"` - RatelimitDuration sql.NullInt64 `db:"ratelimit_duration"` - Environment sql.NullString `db:"environment"` } // InsertKey @@ -96,11 +84,7 @@ type InsertKeyParams struct { // enabled, // remaining_requests, // refill_day, -// refill_amount, -// ratelimit_async, -// ratelimit_limit, -// ratelimit_duration, -// environment +// refill_amount // ) VALUES ( // ?, // ?, @@ -117,10 +101,6 @@ type InsertKeyParams struct { // ?, // ?, // ?, -// ?, -// ?, -// ?, -// ?, // ? // ) func (q *Queries) InsertKey(ctx context.Context, db DBTX, arg InsertKeyParams) error { @@ -140,10 +120,6 @@ func (q *Queries) InsertKey(ctx context.Context, db DBTX, arg InsertKeyParams) e arg.RemainingRequests, arg.RefillDay, arg.RefillAmount, - arg.RatelimitAsync, - arg.RatelimitLimit, - arg.RatelimitDuration, - arg.Environment, ) return err } diff --git a/go/pkg/db/key_update_credits.sql_generated.go b/go/pkg/db/key_update_credits.sql_generated.go index f001ec7cac..9465a0f7a4 100644 --- a/go/pkg/db/key_update_credits.sql_generated.go +++ b/go/pkg/db/key_update_credits.sql_generated.go @@ -12,7 +12,8 @@ import ( const updateKeyCredits = `-- name: UpdateKeyCredits :exec UPDATE ` + "`" + `keys` + "`" + ` -SET remaining_requests = CASE +SET remaining_requests = +CASE WHEN ? = 'set' THEN ? WHEN ? = 'increment' THEN remaining_requests + ? WHEN ? = 'decrement' AND remaining_requests - ? > 0 THEN remaining_requests - ? @@ -30,7 +31,8 @@ type UpdateKeyCreditsParams struct { // UpdateKeyCredits // // UPDATE `keys` -// SET remaining_requests = CASE +// SET remaining_requests = +// CASE // WHEN ? = 'set' THEN ? // WHEN ? = 'increment' THEN remaining_requests + ? // WHEN ? = 'decrement' AND remaining_requests - ? > 0 THEN remaining_requests - ? diff --git a/go/pkg/db/querier_generated.go b/go/pkg/db/querier_generated.go index 63f138bd0a..2e4d195e76 100644 --- a/go/pkg/db/querier_generated.go +++ b/go/pkg/db/querier_generated.go @@ -174,68 +174,92 @@ type Querier interface { // ELSE FALSE // END) AND k.deleted_at_m IS NULL AND a.deleted_at_m IS NULL FindKeyByIdOrHash(ctx context.Context, db DBTX, arg FindKeyByIdOrHashParams) (FindKeyByIdOrHashRow, error) + //FindKeyCredits + // + // SELECT remaining_requests FROM `keys` k WHERE k.id = ? + FindKeyCredits(ctx context.Context, db DBTX, id string) (sql.NullInt32, error) //FindKeyEncryptionByKeyID // // SELECT workspace_id, key_id, created_at, updated_at, encrypted, encryption_key_id FROM encrypted_keys WHERE key_id = ? FindKeyEncryptionByKeyID(ctx context.Context, db DBTX, keyID string) (EncryptedKey, error) //FindKeyForVerification // - // WITH direct_permissions AS ( - // SELECT kp.key_id, p.name as permission_name - // FROM keys_permissions kp - // JOIN permissions p ON kp.permission_id = p.id - // ), - // role_permissions AS ( - // SELECT kr.key_id, p.name as permission_name - // FROM keys_roles kr - // JOIN roles_permissions rp ON kr.role_id = rp.role_id - // JOIN permissions p ON rp.permission_id = p.id - // ), - // all_permissions AS ( - // SELECT key_id, permission_name FROM direct_permissions - // UNION - // SELECT key_id, permission_name FROM role_permissions - // ), - // all_ratelimits AS ( - // SELECT - // key_id as target_id, - // 'key' as target_type, - // name, - // `limit`, - // duration - // FROM ratelimits - // WHERE key_id IS NOT NULL - // UNION - // SELECT - // identity_id as target_id, - // 'identity' as target_type, - // name, - // `limit`, - // duration - // FROM ratelimits - // WHERE identity_id IS NOT NULL - // ) - // SELECT - // k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, - // i.id, i.external_id, i.workspace_id, i.environment, i.meta, i.deleted, i.created_at, i.updated_at, - // JSON_ARRAYAGG( - // JSON_OBJECT( - // 'target_type', rl.target_type, - // 'name', rl.name, - // 'limit', rl.limit, - // 'duration', rl.duration - // ) - // ) as ratelimits, - // GROUP_CONCAT(DISTINCT perms.permission_name) as permissions - // FROM `keys` k - // LEFT JOIN identities i ON k.identity_id = i.id - // LEFT JOIN all_permissions perms ON k.id = perms.key_id - // LEFT JOIN all_ratelimits rl ON ( - // (rl.target_type = 'key' AND rl.target_id = k.id) OR - // (rl.target_type = 'identity' AND rl.target_id = k.identity_id) - // ) - // WHERE k.hash = ? - // GROUP BY k.id + // select k.id, + // k.key_auth_id, + // k.workspace_id, + // k.for_workspace_id, + // k.name, + // k.meta, + // k.expires, + // k.deleted_at_m, + // k.refill_day, + // k.refill_amount, + // k.last_refill_at, + // k.enabled, + // k.remaining_requests, + // a.ip_whitelist, + // a.workspace_id as api_workspace_id, + // a.id as api_id, + // a.deleted_at_m as api_deleted_at_m, + // + // COALESCE( + // (SELECT JSON_ARRAYAGG(name) + // FROM (SELECT name + // FROM keys_roles kr + // JOIN roles r ON r.id = kr.role_id + // WHERE kr.key_id = k.id) as roles), + // JSON_ARRAY() + // ) as roles, + // + // COALESCE( + // (SELECT JSON_ARRAYAGG(slug) + // FROM (SELECT slug + // FROM keys_permissions kp + // JOIN permissions p ON kp.permission_id = p.id + // WHERE kp.key_id = k.id + // + // UNION ALL + // + // SELECT slug + // FROM keys_roles kr + // JOIN roles_permissions rp ON kr.role_id = rp.role_id + // JOIN permissions p ON rp.permission_id = p.id + // WHERE kr.key_id = k.id) as combined_perms), + // JSON_ARRAY() + // ) as permissions, + // + // coalesce( + // (select json_arrayagg( + // json_object( + // 'id', rl.id, + // 'name', rl.name, + // 'key_id', rl.key_id, + // 'identity_id', rl.identity_id, + // 'limit', rl.limit, + // 'duration', rl.duration, + // 'auto_apply', rl.auto_apply + // ) + // ) + // from `ratelimits` rl + // where rl.key_id = k.id + // OR rl.identity_id = i.id), + // json_array() + // ) as ratelimits, + // + // i.id as identity_id, + // i.external_id, + // i.meta as identity_meta, + // ka.deleted_at_m as key_auth_deleted_at_m, + // ws.enabled as workspace_enabled, + // fws.enabled as for_workspace_enabled + // from `keys` k + // JOIN apis a USING (key_auth_id) + // JOIN key_auth ka ON ka.id = k.key_auth_id + // JOIN workspaces ws ON ws.id = k.workspace_id + // LEFT JOIN workspaces fws ON fws.id = k.for_workspace_id + // LEFT JOIN identities i ON k.identity_id = i.id AND i.deleted = 0 + // where k.hash = ? + // and k.deleted_at_m is null FindKeyForVerification(ctx context.Context, db DBTX, hash string) (FindKeyForVerificationRow, error) //FindKeyRoleByKeyAndRoleID // @@ -326,6 +350,27 @@ type Querier interface { // FROM projects // WHERE workspace_id = ? AND slug = ? FindProjectByWorkspaceSlug(ctx context.Context, db DBTX, arg FindProjectByWorkspaceSlugParams) (Project, error) + //FindRatelimitNamespace + // + // SELECT id, workspace_id, name, created_at_m, updated_at_m, deleted_at_m, + // coalesce( + // (select json_arrayagg( + // json_object( + // 'id', ro.id, + // 'identifier', ro.identifier, + // 'limit', ro.limit, + // 'duration', ro.duration + // ) + // ) + // from ratelimit_overrides ro where ro.namespace_id = ns.id AND ro.deleted_at_m IS NULL), + // json_array() + // ) as overrides + // FROM `ratelimit_namespaces` ns + // WHERE ns.workspace_id = ? + // AND CASE WHEN ? IS NOT NULL THEN ns.name = ? + // WHEN ? IS NOT NULL THEN ns.id = ? + // ELSE false END + FindRatelimitNamespace(ctx context.Context, db DBTX, arg FindRatelimitNamespaceParams) (FindRatelimitNamespaceRow, error) //FindRatelimitNamespaceByID // // SELECT id, workspace_id, name, created_at_m, updated_at_m, deleted_at_m FROM `ratelimit_namespaces` @@ -424,6 +469,7 @@ type Querier interface { // name, // workspace_id, // auth_type, + // ip_whitelist, // key_auth_id, // created_at_m, // deleted_at_m @@ -434,6 +480,7 @@ type Querier interface { // ?, // ?, // ?, + // ?, // NULL // ) InsertApi(ctx context.Context, db DBTX, arg InsertApiParams) error @@ -617,11 +664,7 @@ type Querier interface { // enabled, // remaining_requests, // refill_day, - // refill_amount, - // ratelimit_async, - // ratelimit_limit, - // ratelimit_duration, - // environment + // refill_amount // ) VALUES ( // ?, // ?, @@ -638,10 +681,6 @@ type Querier interface { // ?, // ?, // ?, - // ?, - // ?, - // ?, - // ?, // ? // ) InsertKey(ctx context.Context, db DBTX, arg InsertKeyParams) error @@ -1002,18 +1041,6 @@ type Querier interface { // WHERE rp.role_id = ? // ORDER BY p.slug ListPermissionsByRoleID(ctx context.Context, db DBTX, roleID string) ([]Permission, error) - //ListRatelimitOverrideMatches - // - // SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at_m, updated_at_m, deleted_at_m FROM ratelimit_overrides - // WHERE - // workspace_id = ? - // AND namespace_id = ? - // AND ? LIKE - // REPLACE( - // REPLACE(identifier, '*', '%'), -- Replace * with % wildcard - // '_', '\\_' -- Escape underscore literals - // ) - ListRatelimitOverrideMatches(ctx context.Context, db DBTX, arg ListRatelimitOverrideMatchesParams) ([]RatelimitOverride, error) //ListRatelimitOverridesByNamespaceID // // SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at_m, updated_at_m, deleted_at_m FROM ratelimit_overrides @@ -1157,7 +1184,8 @@ type Querier interface { //UpdateKeyCredits // // UPDATE `keys` - // SET remaining_requests = CASE + // SET remaining_requests = + // CASE // WHEN ? = 'set' THEN ? // WHEN ? = 'increment' THEN remaining_requests + ? // WHEN ? = 'decrement' AND remaining_requests - ? > 0 THEN remaining_requests - ? diff --git a/go/pkg/db/queries/api_insert.sql b/go/pkg/db/queries/api_insert.sql index a4c992d480..793b7e686d 100644 --- a/go/pkg/db/queries/api_insert.sql +++ b/go/pkg/db/queries/api_insert.sql @@ -4,6 +4,7 @@ INSERT INTO apis ( name, workspace_id, auth_type, + ip_whitelist, key_auth_id, created_at_m, deleted_at_m @@ -14,5 +15,6 @@ INSERT INTO apis ( ?, ?, ?, + ?, NULL ); diff --git a/go/pkg/db/queries/key_find_credits.sql b/go/pkg/db/queries/key_find_credits.sql new file mode 100644 index 0000000000..7f2f71aee2 --- /dev/null +++ b/go/pkg/db/queries/key_find_credits.sql @@ -0,0 +1,2 @@ +-- name: FindKeyCredits :one +SELECT remaining_requests FROM `keys` k WHERE k.id = ?; diff --git a/go/pkg/db/queries/key_find_for_verification.sql b/go/pkg/db/queries/key_find_for_verification.sql index 0f765cfb14..d649de3567 100644 --- a/go/pkg/db/queries/key_find_for_verification.sql +++ b/go/pkg/db/queries/key_find_for_verification.sql @@ -1,57 +1,77 @@ -- name: FindKeyForVerification :one -WITH direct_permissions AS ( - SELECT kp.key_id, p.name as permission_name - FROM keys_permissions kp - JOIN permissions p ON kp.permission_id = p.id -), -role_permissions AS ( - SELECT kr.key_id, p.name as permission_name - FROM keys_roles kr - JOIN roles_permissions rp ON kr.role_id = rp.role_id - JOIN permissions p ON rp.permission_id = p.id -), -all_permissions AS ( - SELECT * FROM direct_permissions - UNION - SELECT * FROM role_permissions -), -all_ratelimits AS ( - SELECT - key_id as target_id, - 'key' as target_type, - name, - `limit`, - duration - FROM ratelimits - WHERE key_id IS NOT NULL - UNION - SELECT - identity_id as target_id, - 'identity' as target_type, - name, - `limit`, - duration - FROM ratelimits - WHERE identity_id IS NOT NULL -) -SELECT - sqlc.embed(k), - sqlc.embed(i), - JSON_ARRAYAGG( - JSON_OBJECT( - 'target_type', rl.target_type, - 'name', rl.name, - 'limit', rl.limit, - 'duration', rl.duration - ) - ) as ratelimits, - GROUP_CONCAT(DISTINCT perms.permission_name) as permissions -FROM `keys` k -LEFT JOIN identities i ON k.identity_id = i.id -LEFT JOIN all_permissions perms ON k.id = perms.key_id -LEFT JOIN all_ratelimits rl ON ( - (rl.target_type = 'key' AND rl.target_id = k.id) OR - (rl.target_type = 'identity' AND rl.target_id = k.identity_id) -) -WHERE k.hash = ? -GROUP BY k.id; +select k.id, + k.key_auth_id, + k.workspace_id, + k.for_workspace_id, + k.name, + k.meta, + k.expires, + k.deleted_at_m, + k.refill_day, + k.refill_amount, + k.last_refill_at, + k.enabled, + k.remaining_requests, + a.ip_whitelist, + a.workspace_id as api_workspace_id, + a.id as api_id, + a.deleted_at_m as api_deleted_at_m, + + COALESCE( + (SELECT JSON_ARRAYAGG(name) + FROM (SELECT name + FROM keys_roles kr + JOIN roles r ON r.id = kr.role_id + WHERE kr.key_id = k.id) as roles), + JSON_ARRAY() + ) as roles, + + COALESCE( + (SELECT JSON_ARRAYAGG(slug) + FROM (SELECT slug + FROM keys_permissions kp + JOIN permissions p ON kp.permission_id = p.id + WHERE kp.key_id = k.id + + UNION ALL + + SELECT slug + FROM keys_roles kr + JOIN roles_permissions rp ON kr.role_id = rp.role_id + JOIN permissions p ON rp.permission_id = p.id + WHERE kr.key_id = k.id) as combined_perms), + JSON_ARRAY() + ) as permissions, + + coalesce( + (select json_arrayagg( + json_object( + 'id', rl.id, + 'name', rl.name, + 'key_id', rl.key_id, + 'identity_id', rl.identity_id, + 'limit', rl.limit, + 'duration', rl.duration, + 'auto_apply', rl.auto_apply + ) + ) + from `ratelimits` rl + where rl.key_id = k.id + OR rl.identity_id = i.id), + json_array() + ) as ratelimits, + + i.id as identity_id, + i.external_id, + i.meta as identity_meta, + ka.deleted_at_m as key_auth_deleted_at_m, + ws.enabled as workspace_enabled, + fws.enabled as for_workspace_enabled +from `keys` k + JOIN apis a USING (key_auth_id) + JOIN key_auth ka ON ka.id = k.key_auth_id + JOIN workspaces ws ON ws.id = k.workspace_id + LEFT JOIN workspaces fws ON fws.id = k.for_workspace_id + LEFT JOIN identities i ON k.identity_id = i.id AND i.deleted = 0 +where k.hash = ? + and k.deleted_at_m is null; diff --git a/go/pkg/db/queries/key_insert.sql b/go/pkg/db/queries/key_insert.sql index 8461b8c85c..072e9fea12 100644 --- a/go/pkg/db/queries/key_insert.sql +++ b/go/pkg/db/queries/key_insert.sql @@ -15,11 +15,7 @@ INSERT INTO `keys` ( enabled, remaining_requests, refill_day, - refill_amount, - ratelimit_async, - ratelimit_limit, - ratelimit_duration, - environment + refill_amount ) VALUES ( sqlc.arg(id), sqlc.arg(keyring_id), @@ -36,9 +32,5 @@ INSERT INTO `keys` ( sqlc.arg(enabled), sqlc.arg(remaining_requests), sqlc.arg(refill_day), - sqlc.arg(refill_amount), - sqlc.arg(ratelimit_async), - sqlc.arg(ratelimit_limit), - sqlc.arg(ratelimit_duration), - sqlc.arg(environment) + sqlc.arg(refill_amount) ); diff --git a/go/pkg/db/queries/key_update_credits.sql b/go/pkg/db/queries/key_update_credits.sql index 96f9920e30..f7d5929f0d 100644 --- a/go/pkg/db/queries/key_update_credits.sql +++ b/go/pkg/db/queries/key_update_credits.sql @@ -1,6 +1,7 @@ -- name: UpdateKeyCredits :exec UPDATE `keys` -SET remaining_requests = CASE +SET remaining_requests = +CASE WHEN sqlc.narg('operation') = 'set' THEN sqlc.narg('credits') WHEN sqlc.narg('operation') = 'increment' THEN remaining_requests + sqlc.narg('credits') WHEN sqlc.narg('operation') = 'decrement' AND remaining_requests - sqlc.narg('credits') > 0 THEN remaining_requests - sqlc.narg('credits') diff --git a/go/pkg/db/queries/ratelimit_namespace_find.sql b/go/pkg/db/queries/ratelimit_namespace_find.sql new file mode 100644 index 0000000000..dbd9e8394e --- /dev/null +++ b/go/pkg/db/queries/ratelimit_namespace_find.sql @@ -0,0 +1,19 @@ +-- name: FindRatelimitNamespace :one +SELECT *, + coalesce( + (select json_arrayagg( + json_object( + 'id', ro.id, + 'identifier', ro.identifier, + 'limit', ro.limit, + 'duration', ro.duration + ) + ) + from ratelimit_overrides ro where ro.namespace_id = ns.id AND ro.deleted_at_m IS NULL), + json_array() + ) as overrides +FROM `ratelimit_namespaces` ns +WHERE ns.workspace_id = ? +AND CASE WHEN sqlc.narg('name') IS NOT NULL THEN ns.name = sqlc.narg('name') +WHEN sqlc.narg('id') IS NOT NULL THEN ns.id = sqlc.narg('id') +ELSE false END; diff --git a/go/pkg/db/queries/ratelimit_override_list_matches.sql b/go/pkg/db/queries/ratelimit_override_list_matches.sql deleted file mode 100644 index 2b75b7b8c6..0000000000 --- a/go/pkg/db/queries/ratelimit_override_list_matches.sql +++ /dev/null @@ -1,10 +0,0 @@ --- name: ListRatelimitOverrideMatches :many -SELECT * FROM ratelimit_overrides -WHERE - workspace_id = sqlc.arg(workspace_id) - AND namespace_id = sqlc.arg(namespace_id) - AND sqlc.arg(identifier) LIKE - REPLACE( - REPLACE(identifier, '*', '%'), -- Replace * with % wildcard - '_', '\\_' -- Escape underscore literals - ); diff --git a/go/pkg/db/ratelimit_namespace_find.sql.go b/go/pkg/db/ratelimit_namespace_find.sql.go new file mode 100644 index 0000000000..2a4bfb60ae --- /dev/null +++ b/go/pkg/db/ratelimit_namespace_find.sql.go @@ -0,0 +1,23 @@ +package db + +import ( + "database/sql" +) + +type FindRatelimitNamespaceLimitOverride struct { + ID string `json:"id"` + Identifier string `json:"identifier"` + Limit int64 `json:"limit"` + Duration int64 `json:"duration"` +} + +type FindRatelimitNamespace struct { + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + Name string `db:"name"` + CreatedAtM int64 `db:"created_at_m"` + UpdatedAtM sql.NullInt64 `db:"updated_at_m"` + DeletedAtM sql.NullInt64 `db:"deleted_at_m"` + DirectOverrides map[string]FindRatelimitNamespaceLimitOverride `db:"direct_overrides"` + WildcardOverrides []FindRatelimitNamespaceLimitOverride `db:"wildcard_overrides"` +} diff --git a/go/pkg/db/ratelimit_namespace_find.sql_generated.go b/go/pkg/db/ratelimit_namespace_find.sql_generated.go new file mode 100644 index 0000000000..e054ec25af --- /dev/null +++ b/go/pkg/db/ratelimit_namespace_find.sql_generated.go @@ -0,0 +1,89 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: ratelimit_namespace_find.sql + +package db + +import ( + "context" + "database/sql" +) + +const findRatelimitNamespace = `-- name: FindRatelimitNamespace :one +SELECT id, workspace_id, name, created_at_m, updated_at_m, deleted_at_m, + coalesce( + (select json_arrayagg( + json_object( + 'id', ro.id, + 'identifier', ro.identifier, + 'limit', ro.limit, + 'duration', ro.duration + ) + ) + from ratelimit_overrides ro where ro.namespace_id = ns.id AND ro.deleted_at_m IS NULL), + json_array() + ) as overrides +FROM ` + "`" + `ratelimit_namespaces` + "`" + ` ns +WHERE ns.workspace_id = ? +AND CASE WHEN ? IS NOT NULL THEN ns.name = ? +WHEN ? IS NOT NULL THEN ns.id = ? +ELSE false END +` + +type FindRatelimitNamespaceParams struct { + WorkspaceID string `db:"workspace_id"` + Name sql.NullString `db:"name"` + ID sql.NullString `db:"id"` +} + +type FindRatelimitNamespaceRow struct { + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + Name string `db:"name"` + CreatedAtM int64 `db:"created_at_m"` + UpdatedAtM sql.NullInt64 `db:"updated_at_m"` + DeletedAtM sql.NullInt64 `db:"deleted_at_m"` + Overrides interface{} `db:"overrides"` +} + +// FindRatelimitNamespace +// +// SELECT id, workspace_id, name, created_at_m, updated_at_m, deleted_at_m, +// coalesce( +// (select json_arrayagg( +// json_object( +// 'id', ro.id, +// 'identifier', ro.identifier, +// 'limit', ro.limit, +// 'duration', ro.duration +// ) +// ) +// from ratelimit_overrides ro where ro.namespace_id = ns.id AND ro.deleted_at_m IS NULL), +// json_array() +// ) as overrides +// FROM `ratelimit_namespaces` ns +// WHERE ns.workspace_id = ? +// AND CASE WHEN ? IS NOT NULL THEN ns.name = ? +// WHEN ? IS NOT NULL THEN ns.id = ? +// ELSE false END +func (q *Queries) FindRatelimitNamespace(ctx context.Context, db DBTX, arg FindRatelimitNamespaceParams) (FindRatelimitNamespaceRow, error) { + row := db.QueryRowContext(ctx, findRatelimitNamespace, + arg.WorkspaceID, + arg.Name, + arg.Name, + arg.ID, + arg.ID, + ) + var i FindRatelimitNamespaceRow + err := row.Scan( + &i.ID, + &i.WorkspaceID, + &i.Name, + &i.CreatedAtM, + &i.UpdatedAtM, + &i.DeletedAtM, + &i.Overrides, + ) + return i, err +} diff --git a/go/pkg/match/doc.go b/go/pkg/match/doc.go new file mode 100644 index 0000000000..310c3cbdd2 --- /dev/null +++ b/go/pkg/match/doc.go @@ -0,0 +1,2 @@ +// Package match provides pattern matching utilities. +package match diff --git a/go/pkg/match/wildcard.go b/go/pkg/match/wildcard.go new file mode 100644 index 0000000000..c536d61cd1 --- /dev/null +++ b/go/pkg/match/wildcard.go @@ -0,0 +1,44 @@ +package match + +import ( + "regexp" + "strings" +) + +// Wildcard checks if a string matches a wildcard pattern. +// The pattern can contain '*' as a wildcard that matches any sequence of characters. +// +// Examples: +// - Wildcard("test@gmail.com", "*@gmail.com") returns true +// - Wildcard("test@yahoo.com", "*@gmail.com") returns false +// - Wildcard("hello world", "hello*") returns true +// - Wildcard("hello world", "*world") returns true +// - Wildcard("hello world", "h*d") returns true +func Wildcard(s, pattern string) (bool, error) { + // Fast path for patterns without wildcards + if !strings.Contains(pattern, "*") { + return s == pattern, nil + } + + // Convert wildcard pattern to regex pattern + // Escape special regex characters except * + regexPattern := "" + for i := 0; i < len(pattern); i++ { + ch := pattern[i] + switch ch { + case '*': + regexPattern += ".*" + case '.', '^', '$', '+', '?', '{', '}', '[', ']', '(', ')', '|', '\\': + regexPattern += "\\" + string(ch) + default: + regexPattern += string(ch) + } + } + + // Anchor the pattern to match the entire string + regexPattern = "^" + regexPattern + "$" + + // Check if the pattern matches + matched, err := regexp.MatchString(regexPattern, s) + return matched, err +} diff --git a/go/pkg/match/wildcard_test.go b/go/pkg/match/wildcard_test.go new file mode 100644 index 0000000000..d2b780f584 --- /dev/null +++ b/go/pkg/match/wildcard_test.go @@ -0,0 +1,123 @@ +package match + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWildcard(t *testing.T) { + tests := []struct { + name string + s string + pattern string + expected bool + }{ + // Email patterns + { + name: "email wildcard match gmail", + s: "test@gmail.com", + pattern: "*@gmail.com", + expected: true, + }, + { + name: "email wildcard no match different domain", + s: "test@yahoo.com", + pattern: "*@gmail.com", + expected: false, + }, + { + name: "email exact match", + s: "test@gmail.com", + pattern: "test@gmail.com", + expected: true, + }, + // Prefix patterns + { + name: "prefix wildcard match", + s: "hello world", + pattern: "hello*", + expected: true, + }, + { + name: "prefix wildcard no match", + s: "goodbye world", + pattern: "hello*", + expected: false, + }, + // Suffix patterns + { + name: "suffix wildcard match", + s: "hello world", + pattern: "*world", + expected: true, + }, + { + name: "suffix wildcard no match", + s: "hello earth", + pattern: "*world", + expected: false, + }, + // Middle patterns + { + name: "middle wildcard match", + s: "hello world", + pattern: "h*d", + expected: true, + }, + { + name: "multiple wildcards", + s: "hello beautiful world", + pattern: "h*beau*world", + expected: true, + }, + // Special regex characters + { + name: "dots are literal", + s: "test@gmail.com", + pattern: "*@gmail.com", + expected: true, + }, + { + name: "dots must match exactly", + s: "test@gmailxcom", + pattern: "*@gmail.com", + expected: false, + }, + // Edge cases + { + name: "empty string with wildcard", + s: "", + pattern: "*", + expected: true, + }, + { + name: "wildcard only", + s: "anything", + pattern: "*", + expected: true, + }, + { + name: "no wildcard exact match", + s: "exact", + pattern: "exact", + expected: true, + }, + { + name: "no wildcard no match", + s: "different", + pattern: "exact", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Wildcard(tt.s, tt.pattern) + require.NoError(t, err) + if result != tt.expected { + t.Errorf("Wildcard(%q, %q) = %v, want %v", tt.s, tt.pattern, result, tt.expected) + } + }) + } +} diff --git a/go/pkg/prometheus/metrics/keys.go b/go/pkg/prometheus/metrics/keys.go new file mode 100644 index 0000000000..bb34709415 --- /dev/null +++ b/go/pkg/prometheus/metrics/keys.go @@ -0,0 +1,42 @@ +/* +Package metrics provides Prometheus metric collectors for monitoring application performance. + +This file contains Key-Verification-related metrics for tracking what keys do. +*/ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // KeyVerificationsTotal tracks the number of key verifications handled, labeled by some data. + // Use this counter to monitor API traffic patterns and error rates. + // + // Example usage: + // metrics.KeyVerificationsTotal.WithLabelValues("ws_1234", "api_5678", "key_abcd", "true", "VALID").Inc() + KeyVerificationsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: "key", + Name: "verifications_total", + Help: "Total number of Key verifications processed.", + ConstLabels: constLabels, + }, + []string{"workspaceId", "apiId", "keyId", "valid", "code"}, + ) + + // KeyCreditsSpentTotal tracks the total credits spent by keys, labeled by workspace ID, key ID, and identity ID. + // Use this counter to monitor credit usage patterns and error rates. + // + // Example usage: + // metrics.KeyCreditsSpentTotal.WithLabelValues("ws_1234", "key_abcd", "identity_xyz", "true").Add(5) + KeyCreditsSpentTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: "key", + Name: "credits_spent_total", + Help: "Total credits spent by keys", + }, + []string{"workspace_id", "key_id", "identity_id", "deducted"}, + ) +) diff --git a/go/pkg/rbac/permissions.go b/go/pkg/rbac/permissions.go index 407bffb0e5..fd70d33ec5 100644 --- a/go/pkg/rbac/permissions.go +++ b/go/pkg/rbac/permissions.go @@ -62,6 +62,9 @@ const ( // ReadKey permits viewing API key details ReadKey ActionType = "read_key" + + // VerifyKey permits verifying API keys + VerifyKey ActionType = "verify_key" ) // Predefined rate limiting actions. These constants define operations diff --git a/go/pkg/rbac/rbac.go b/go/pkg/rbac/rbac.go index 19e878d566..27e5fdbd5e 100644 --- a/go/pkg/rbac/rbac.go +++ b/go/pkg/rbac/rbac.go @@ -2,6 +2,7 @@ package rbac import ( "fmt" + "slices" "strings" ) @@ -69,11 +70,10 @@ func (r *RBAC) EvaluatePermissions(query PermissionQuery, permissions []string) func (r *RBAC) evaluateQueryV1(query PermissionQuery, permissions []string) (EvaluationResult, error) { // Handle simple permission check if query.Value != "" { - for _, p := range permissions { - if p == query.Value { - return EvaluationResult{Valid: true, Message: ""}, nil - } + if slices.Contains(permissions, query.Value) { + return EvaluationResult{Valid: true, Message: ""}, nil } + return EvaluationResult{ Valid: false, Message: fmt.Sprintf("Missing permission: '%s'", query.Value), @@ -107,6 +107,7 @@ func (r *RBAC) evaluateQueryV1(query PermissionQuery, permissions []string) (Eva } missingPerms = append(missingPerms, fmt.Sprintf("'%v'", child)) } + return EvaluationResult{ Valid: false, Message: fmt.Sprintf("Missing one of these permissions: [%s], have: [%s]", @@ -120,8 +121,10 @@ func (r *RBAC) evaluateQueryV1(query PermissionQuery, permissions []string) (Eva func formatPermissions(permissions []string) []string { formatted := make([]string, len(permissions)) + for i, p := range permissions { formatted[i] = fmt.Sprintf("'%s'", p) } + return formatted } diff --git a/go/pkg/testutil/http.go b/go/pkg/testutil/http.go index 50a55d0a82..eba64e44dc 100644 --- a/go/pkg/testutil/http.go +++ b/go/pkg/testutil/http.go @@ -12,13 +12,13 @@ import ( "github.com/unkeyed/unkey/go/internal/services/auditlogs" "github.com/unkeyed/unkey/go/internal/services/caches" "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/counter" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/rbac" "github.com/unkeyed/unkey/go/pkg/testutil/containers" "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/vault" @@ -38,16 +38,15 @@ type Harness struct { middleware []zen.Middleware - DB db.Database - Caches caches.Caches - Logger logging.Logger - Keys keys.KeyService - Permissions permissions.PermissionService - Auditlogs auditlogs.AuditLogService - ClickHouse clickhouse.ClickHouse - Ratelimit ratelimit.Service - Vault *vault.Service - seeder *seed.Seeder + DB db.Database + Caches caches.Caches + Logger logging.Logger + Keys keys.KeyService + Auditlogs auditlogs.AuditLogService + ClickHouse clickhouse.ClickHouse + Ratelimit ratelimit.Service + Vault *vault.Service + seeder *seed.Seeder } func NewHarness(t *testing.T) *Harness { @@ -86,15 +85,6 @@ func NewHarness(t *testing.T) *Harness { }) require.NoError(t, err) - keyService, err := keys.New(keys.Config{ - Logger: logger, - DB: db, - Clock: clk, - KeyCache: caches.KeyByHash, - WorkspaceCache: caches.WorkspaceByID, - }) - require.NoError(t, err) - // Get ClickHouse connection string chDSN := containers.ClickHouse(t) @@ -108,14 +98,6 @@ func NewHarness(t *testing.T) *Harness { validator, err := validation.New() require.NoError(t, err) - permissionService, err := permissions.New(permissions.Config{ - DB: db, - Logger: logger, - Clock: clk, - Cache: caches.PermissionsByKeyId, - }) - require.NoError(t, err) - ctr, err := counter.NewRedis(counter.RedisConfig{ RedisURL: redisUrl, Logger: logger, @@ -129,6 +111,16 @@ func NewHarness(t *testing.T) *Harness { }) require.NoError(t, err) + keyService, err := keys.New(keys.Config{ + Logger: logger, + DB: db, + KeyCache: caches.VerificationKeyByHash, + RateLimiter: ratelimitService, + RBAC: rbac.New(), + Clickhouse: ch, + }) + require.NoError(t, err) + s3 := containers.S3(t) vaultStorage, err := storage.NewS3(storage.S3Config{ @@ -155,18 +147,17 @@ func NewHarness(t *testing.T) *Harness { seeder.Seed(context.Background()) h := Harness{ - t: t, - Logger: logger, - srv: srv, - validator: validator, - Keys: keyService, - Permissions: permissionService, - Ratelimit: ratelimitService, - Vault: v, - ClickHouse: ch, - DB: db, - seeder: seeder, - Clock: clk, + t: t, + Logger: logger, + srv: srv, + validator: validator, + Keys: keyService, + Ratelimit: ratelimitService, + Vault: v, + ClickHouse: ch, + DB: db, + seeder: seeder, + Clock: clk, Auditlogs: auditlogs.New(auditlogs.Config{ DB: db, Logger: logger, @@ -205,6 +196,30 @@ func (h *Harness) CreateWorkspace() db.Workspace { return h.seeder.CreateWorkspace(context.Background()) } +func (h *Harness) CreateApi(req seed.CreateApiRequest) db.Api { + return h.seeder.CreateAPI(context.Background(), req) +} + +func (h *Harness) CreateKey(req seed.CreateKeyRequest) seed.CreateKeyResponse { + return h.seeder.CreateKey(context.Background(), req) +} + +func (h *Harness) CreateIdentity(req seed.CreateIdentityRequest) string { + return h.seeder.CreateIdentity(context.Background(), req) +} + +func (h *Harness) CreateRatelimit(req seed.CreateRatelimitRequest) string { + return h.seeder.CreateRatelimit(context.Background(), req) +} + +func (h *Harness) CreateRole(req seed.CreateRoleRequest) string { + return h.seeder.CreateRole(context.Background(), req) +} + +func (h *Harness) CreatePermission(req seed.CreatePermissionRequest) string { + return h.seeder.CreatePermission(context.Background(), req) +} + func (h *Harness) Resources() seed.Resources { return h.seeder.Resources } diff --git a/go/pkg/testutil/seed/seed.go b/go/pkg/testutil/seed/seed.go index dd7df7df72..c414fbdcc2 100644 --- a/go/pkg/testutil/seed/seed.go +++ b/go/pkg/testutil/seed/seed.go @@ -9,8 +9,10 @@ import ( "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/pkg/assert" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -18,6 +20,7 @@ import ( type Resources struct { RootWorkspace db.Workspace RootKeyring db.KeyAuth + RootApi db.Api UserWorkspace db.Workspace } @@ -56,29 +59,60 @@ func (s *Seeder) CreateWorkspace(ctx context.Context) db.Workspace { // Seed initializes the database with test data func (s *Seeder) Seed(ctx context.Context) { - // Insert root workspace - + s.Resources.UserWorkspace = s.CreateWorkspace(ctx) s.Resources.RootWorkspace = s.CreateWorkspace(ctx) + s.Resources.RootApi = s.CreateAPI(ctx, CreateApiRequest{ + WorkspaceID: s.Resources.RootWorkspace.ID, + IpWhitelist: "", + EncryptedKeys: false, + Name: nil, + CreatedAt: nil, + DefaultPrefix: nil, + DefaultBytes: nil, + }) + keyring, err := db.Query.FindKeyringByID(ctx, s.DB.RW(), s.Resources.RootApi.KeyAuthID.String) + require.NoError(s.t, err) + s.Resources.RootKeyring = keyring +} - // Insert root keyring - insertRootKeyringParams := db.InsertKeyringParams{ - ID: uid.New("test_kr"), - WorkspaceID: s.Resources.RootWorkspace.ID, - StoreEncryptedKeys: false, - DefaultPrefix: sql.NullString{String: "test", Valid: true}, - DefaultBytes: sql.NullInt32{Int32: 8, Valid: true}, - CreatedAtM: time.Now().UnixMilli(), - } +type CreateApiRequest struct { + WorkspaceID string + IpWhitelist string + EncryptedKeys bool + Name *string + CreatedAt *int64 + DefaultPrefix *string + DefaultBytes *int32 +} - err := db.Query.InsertKeyring(ctx, s.DB.RW(), insertRootKeyringParams) +func (s *Seeder) CreateAPI(ctx context.Context, req CreateApiRequest) db.Api { + keyAuthID := uid.New(uid.KeyAuthPrefix) + err := db.Query.InsertKeyring(ctx, s.DB.RW(), db.InsertKeyringParams{ + ID: keyAuthID, + WorkspaceID: req.WorkspaceID, + CreatedAtM: time.Now().UnixMilli(), + DefaultPrefix: sql.NullString{String: ptr.SafeDeref(req.DefaultPrefix), Valid: req.DefaultPrefix != nil}, + DefaultBytes: sql.NullInt32{Int32: ptr.SafeDeref(req.DefaultBytes), Valid: req.DefaultBytes != nil}, + StoreEncryptedKeys: req.EncryptedKeys, + }) require.NoError(s.t, err) - s.Resources.RootKeyring, err = db.Query.FindKeyringByID(ctx, s.DB.RW(), insertRootKeyringParams.ID) + apiID := uid.New("api") + err = db.Query.InsertApi(ctx, s.DB.RW(), db.InsertApiParams{ + ID: apiID, + Name: ptr.SafeDeref(req.Name, "test-api"), + WorkspaceID: req.WorkspaceID, + IpWhitelist: sql.NullString{String: req.IpWhitelist, Valid: req.IpWhitelist != ""}, + AuthType: db.NullApisAuthType{Valid: true, ApisAuthType: db.ApisAuthTypeKey}, + KeyAuthID: sql.NullString{Valid: true, String: keyAuthID}, + CreatedAtM: ptr.SafeDeref(req.CreatedAt, time.Now().UnixMilli()), + }) require.NoError(s.t, err) - s.Resources.UserWorkspace = s.CreateWorkspace(ctx) - + api, err := db.Query.FindApiByID(ctx, s.DB.RW(), apiID) require.NoError(s.t, err) + + return api } // CreateRootKey creates a root key with optional permissions @@ -101,10 +135,6 @@ func (s *Seeder) CreateRootKey(ctx context.Context, workspaceID string, permissi RemainingRequests: sql.NullInt32{Int32: 0, Valid: false}, RefillDay: sql.NullInt16{Int16: 0, Valid: false}, RefillAmount: sql.NullInt32{Int32: 0, Valid: false}, - RatelimitAsync: sql.NullBool{Bool: false, Valid: false}, - RatelimitLimit: sql.NullInt32{Int32: 0, Valid: false}, - RatelimitDuration: sql.NullInt64{Int64: 0, Valid: false}, - Environment: sql.NullString{String: "", Valid: false}, } err := db.Query.InsertKey(ctx, s.DB.RW(), insertKeyParams) @@ -148,3 +178,243 @@ func (s *Seeder) CreateRootKey(ctx context.Context, workspaceID string, permissi return key } + +type CreateKeyRequest struct { + Disabled bool + WorkspaceID string + KeyAuthID string + Remaining *int32 + IdentityID *string + Meta *string + Expires *time.Time + Name *string + Deleted bool + + RefillAmount *int32 + RefillDay *int16 + + Permissions []CreatePermissionRequest + Roles []CreateRoleRequest + Ratelimits []CreateRatelimitRequest +} + +type CreateKeyResponse struct { + KeyID string + Key string + + RolesIds []string + PermissionIds []string +} + +func (s *Seeder) CreateKey(ctx context.Context, req CreateKeyRequest) CreateKeyResponse { + keyID := uid.New(uid.KeyPrefix) + key := uid.New("") + start := key[:4] + + err := db.Query.InsertKey(ctx, s.DB.RW(), db.InsertKeyParams{ + ID: keyID, + KeyringID: req.KeyAuthID, + WorkspaceID: req.WorkspaceID, + CreatedAtM: time.Now().UnixMilli(), + Hash: hash.Sha256(key), + Enabled: !req.Disabled, + Start: start, + Name: sql.NullString{String: ptr.SafeDeref(req.Name, "test-key"), Valid: true}, + ForWorkspaceID: sql.NullString{String: "", Valid: false}, + Meta: sql.NullString{String: ptr.SafeDeref(req.Meta, ""), Valid: req.Meta != nil}, + IdentityID: sql.NullString{String: ptr.SafeDeref(req.IdentityID, ""), Valid: req.IdentityID != nil}, + Expires: sql.NullTime{Time: ptr.SafeDeref(req.Expires, time.Now()), Valid: req.Expires != nil}, + RemainingRequests: sql.NullInt32{Int32: ptr.SafeDeref(req.Remaining, 0), Valid: req.Remaining != nil}, + RefillAmount: sql.NullInt32{Int32: ptr.SafeDeref(req.RefillAmount, 0), Valid: req.RefillAmount != nil}, + RefillDay: sql.NullInt16{Int16: ptr.SafeDeref(req.RefillDay, 0), Valid: req.RefillDay != nil}, + }) + require.NoError(s.t, err) + + res := CreateKeyResponse{ + KeyID: keyID, + Key: key, + } + + if req.Deleted { + err = db.Query.SoftDeleteKeyByID(ctx, s.DB.RW(), db.SoftDeleteKeyByIDParams{ + Now: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, + ID: keyID, + }) + + require.NoError(s.t, err) + } + + for _, role := range req.Roles { + roleID := s.CreateRole(ctx, role) + err = db.Query.InsertKeyRole(ctx, s.DB.RW(), db.InsertKeyRoleParams{ + KeyID: keyID, + RoleID: roleID, + WorkspaceID: req.WorkspaceID, + CreatedAtM: time.Now().UnixMilli(), + }) + require.NoError(s.t, err) + res.RolesIds = append(res.RolesIds, roleID) + } + + for _, permission := range req.Permissions { + permissionID := s.CreatePermission(ctx, permission) + err = db.Query.InsertKeyPermission(ctx, s.DB.RW(), db.InsertKeyPermissionParams{ + KeyID: keyID, + PermissionID: permissionID, + WorkspaceID: req.WorkspaceID, + CreatedAt: time.Now().UnixMilli(), + }) + + require.NoError(s.t, err) + res.PermissionIds = append(res.PermissionIds, permissionID) + } + + for _, ratelimit := range req.Ratelimits { + ratelimit.KeyID = ptr.P(keyID) + s.CreateRatelimit(ctx, ratelimit) + } + + return res +} + +type CreateRatelimitRequest struct { + Name string + WorkspaceID string + AutoApply bool + Duration int64 + Limit int32 + IdentityID *string + KeyID *string +} + +func (s *Seeder) CreateRatelimit(ctx context.Context, req CreateRatelimitRequest) string { + ratelimitID := uid.New(uid.RatelimitPrefix) + var err error + if req.IdentityID != nil { + err = db.Query.InsertIdentityRatelimit(ctx, s.DB.RW(), db.InsertIdentityRatelimitParams{ + ID: ratelimitID, + WorkspaceID: req.WorkspaceID, + IdentityID: sql.NullString{String: *req.IdentityID, Valid: true}, + Name: req.Name, + Limit: req.Limit, + Duration: req.Duration, + AutoApply: req.AutoApply, + CreatedAt: time.Now().UnixMilli(), + }) + } + + if req.KeyID != nil { + err = db.Query.InsertKeyRatelimit(ctx, s.DB.RW(), db.InsertKeyRatelimitParams{ + ID: ratelimitID, + WorkspaceID: req.WorkspaceID, + KeyID: sql.NullString{String: *req.KeyID, Valid: true}, + Name: req.Name, + Limit: req.Limit, + Duration: req.Duration, + AutoApply: req.AutoApply, + CreatedAt: time.Now().UnixMilli(), + }) + } + + require.NoError(s.t, err) + + return ratelimitID +} + +type CreateIdentityRequest struct { + WorkspaceID string + ExternalID string + Meta []byte + Ratelimits []CreateRatelimitRequest +} + +func (s *Seeder) CreateIdentity(ctx context.Context, req CreateIdentityRequest) string { + metaBytes := []byte("{}") + if len(req.Meta) > 0 { + metaBytes = req.Meta + } + + require.NoError(s.t, assert.NotEmpty(req.ExternalID, "Identity ExternalID must be set")) + require.NoError(s.t, assert.NotEmpty(req.WorkspaceID, "Identity WorkspaceID must be set")) + + identityId := uid.New(uid.IdentityPrefix) + err := db.Query.InsertIdentity(ctx, s.DB.RW(), db.InsertIdentityParams{ + ID: identityId, + ExternalID: req.ExternalID, + WorkspaceID: req.WorkspaceID, + Environment: "", + CreatedAt: time.Now().UnixMilli(), + Meta: metaBytes, + }) + require.NoError(s.t, err) + + for _, ratelimit := range req.Ratelimits { + ratelimit.IdentityID = ptr.P(identityId) + s.CreateRatelimit(ctx, ratelimit) + } + + return identityId +} + +type CreateRoleRequest struct { + Name string + Description *string + WorkspaceID string + + Permissions []CreatePermissionRequest +} + +func (s *Seeder) CreateRole(ctx context.Context, req CreateRoleRequest) string { + require.NoError(s.t, assert.NotEmpty(req.WorkspaceID, "Role WorkspaceID must be set")) + require.NoError(s.t, assert.NotEmpty(req.Name, "Role Name must be set")) + + roleID := uid.New(uid.PermissionPrefix) + + err := db.Query.InsertRole(ctx, s.DB.RW(), db.InsertRoleParams{ + RoleID: roleID, + WorkspaceID: req.WorkspaceID, + Name: req.Name, + CreatedAt: time.Now().UnixMilli(), + Description: sql.NullString{Valid: req.Description != nil, String: ptr.SafeDeref(req.Description, "")}, + }) + require.NoError(s.t, err) + + for _, permission := range req.Permissions { + permissionID := s.CreatePermission(ctx, permission) + err = db.Query.InsertRolePermission(ctx, s.DB.RW(), db.InsertRolePermissionParams{ + RoleID: roleID, + PermissionID: permissionID, + WorkspaceID: req.WorkspaceID, + CreatedAtM: time.Now().UnixMilli(), + }) + require.NoError(s.t, err) + } + + return roleID +} + +type CreatePermissionRequest struct { + Name string + Slug string + Description *string + WorkspaceID string +} + +func (s *Seeder) CreatePermission(ctx context.Context, req CreatePermissionRequest) string { + require.NoError(s.t, assert.NotEmpty(req.WorkspaceID, "Permission WorkspaceID must be set")) + require.NoError(s.t, assert.NotEmpty(req.WorkspaceID, "Permission Name must be set")) + require.NoError(s.t, assert.NotEmpty(req.WorkspaceID, "Permission Slug must be set")) + + permissionID := uid.New(uid.PermissionPrefix) + err := db.Query.InsertPermission(ctx, s.DB.RW(), db.InsertPermissionParams{ + PermissionID: permissionID, + WorkspaceID: req.WorkspaceID, + Name: req.Name, + Slug: req.Slug, + Description: sql.NullString{Valid: req.Description != nil, String: ptr.SafeDeref(req.Description, "")}, + CreatedAtM: time.Now().UnixMilli(), + }) + require.NoError(s.t, err) + + return permissionID +} diff --git a/go/pkg/vault/storage/memory.go b/go/pkg/vault/storage/memory.go index 514ffa8a50..a122037245 100644 --- a/go/pkg/vault/storage/memory.go +++ b/go/pkg/vault/storage/memory.go @@ -22,7 +22,6 @@ type MemoryConfig struct { } func NewMemory(config MemoryConfig) (Storage, error) { - logger := config.Logger.With("service", "storage") return &memory{ diff --git a/go/pkg/vault/storage/s3.go b/go/pkg/vault/storage/s3.go index 190a8223e0..a8f8eb3003 100644 --- a/go/pkg/vault/storage/s3.go +++ b/go/pkg/vault/storage/s3.go @@ -31,7 +31,6 @@ type S3Config struct { } func NewS3(config S3Config) (Storage, error) { - logger := config.Logger.With("service", "storage") logger.Info("using s3 storage") @@ -43,7 +42,6 @@ func NewS3(config S3Config) (Storage, error) { URL: config.S3URL, HostnameImmutable: true, }, nil - }) cfg, err := awsConfig.LoadDefaultConfig(context.Background(), diff --git a/go/pkg/zen/session.go b/go/pkg/zen/session.go index 64db225134..457bfa1ed2 100644 --- a/go/pkg/zen/session.go +++ b/go/pkg/zen/session.go @@ -347,7 +347,6 @@ func (s *Session) JSON(status int, body any) error { // // Unlike [JSON], this method does not set any Content-Type header automatically. func (s *Session) Send(status int, body []byte) error { - return s.send(status, body) } diff --git a/internal/db/src/schema/identity.ts b/internal/db/src/schema/identity.ts index addbb6da1a..78adaf1972 100644 --- a/internal/db/src/schema/identity.ts +++ b/internal/db/src/schema/identity.ts @@ -78,13 +78,11 @@ export const ratelimits = mysqlTable( }, (table) => ({ nameIdx: index("name_idx").on(table.name), - uniqueNamePerKey: uniqueIndex("unique_name_per_key_idx").on(table.name, table.keyId), + uniqueNamePerKey: uniqueIndex("unique_name_per_key_idx").on(table.keyId, table.name), uniqueNamePerIdentity: uniqueIndex("unique_name_per_identity_idx").on( - table.name, table.identityId, + table.name, ), - identityId: index("identity_id_idx").on(table.identityId), - keyId: index("key_id_idx").on(table.keyId), }), ); diff --git a/internal/db/src/schema/ratelimit.ts b/internal/db/src/schema/ratelimit.ts index bab29713e5..a3ecd09e72 100644 --- a/internal/db/src/schema/ratelimit.ts +++ b/internal/db/src/schema/ratelimit.ts @@ -15,8 +15,8 @@ export const ratelimitNamespaces = mysqlTable( (table) => { return { uniqueNamePerWorkspaceIdx: unique("unique_name_per_workspace_idx").on( - table.name, table.workspaceId, + table.name, ), }; }, diff --git a/internal/db/src/schema/rbac.ts b/internal/db/src/schema/rbac.ts index 9a902fe812..1bc4855825 100644 --- a/internal/db/src/schema/rbac.ts +++ b/internal/db/src/schema/rbac.ts @@ -27,14 +27,13 @@ export const permissions = mysqlTable( }, (table) => { return { - workspaceIdIdx: index("workspace_id_idx").on(table.workspaceId), uniqueNamePerWorkspaceIdx: unique("unique_name_per_workspace_idx").on( - table.name, table.workspaceId, + table.name, ), uniqueSlugPerWorkspaceIdx: unique("unique_slug_per_workspace_idx").on( - table.slug, table.workspaceId, + table.slug, ), }; },