diff --git a/go/apps/api/routes/v2_keys_get_key/400_test.go b/go/apps/api/routes/v2_keys_get_key/400_test.go index 1a9c1518db..773aa67727 100644 --- a/go/apps/api/routes/v2_keys_get_key/400_test.go +++ b/go/apps/api/routes/v2_keys_get_key/400_test.go @@ -33,16 +33,26 @@ func TestGetKeyBadRequest(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, } - t.Run("empty keyId string", func(t *testing.T) { - req := handler.Request{ - KeyId: "", - Decrypt: ptr.P(false), - } + req := handler.Request{ + KeyId: "", + Decrypt: ptr.P(false), + } + t.Run("empty keyId string", func(t *testing.T) { res := testutil.CallRoute[handler.Request, openapi.BadRequestErrorResponse](h, route, headers, req) require.Equal(t, 400, res.Status) require.NotNil(t, res.Body) require.NotNil(t, res.Body.Error) }) + t.Run("invalid key format", func(t *testing.T) { + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer invalid_key_format_not_uid"}, + } + res := testutil.CallRoute[handler.Request, openapi.UnauthorizedErrorResponse](h, route, headers, req) + require.Equal(t, 400, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + }) } diff --git a/go/apps/api/routes/v2_keys_get_key/403_test.go b/go/apps/api/routes/v2_keys_get_key/403_test.go index 1a0ef725ef..6dcab70a8f 100644 --- a/go/apps/api/routes/v2_keys_get_key/403_test.go +++ b/go/apps/api/routes/v2_keys_get_key/403_test.go @@ -19,7 +19,6 @@ import ( ) func TestGetKeyForbidden(t *testing.T) { - h := testutil.NewHarness(t) ctx := context.Background() @@ -199,4 +198,52 @@ func TestGetKeyForbidden(t *testing.T) { require.Equal(t, 403, res.Status) require.NotNil(t, res.Body) }) + + t.Run("decrypt permission without read permission", func(t *testing.T) { + // Create root key with only decrypt permission, no read permission + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.decrypt_key") + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + // Try to get key + readReq := handler.Request{ + KeyId: keyID, + Decrypt: ptr.P(false), // Even without decrypt, should fail on read permission + } + + res := testutil.CallRoute[handler.Request, openapi.ForbiddenErrorResponse](h, route, headers, readReq) + require.Equal(t, 403, res.Status) + require.NotNil(t, res.Body) + }) + + t.Run("wrong resource type permissions", func(t *testing.T) { + // Create root key with permissions for different resource type + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "workspace.*.read", "identity.*.read") + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + res := testutil.CallRoute[handler.Request, openapi.ForbiddenErrorResponse](h, route, headers, req) + require.Equal(t, 403, res.Status) + require.NotNil(t, res.Body) + }) + + t.Run("specific API permission but wrong action", func(t *testing.T) { + // Create root key with permission for correct API but wrong action + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, fmt.Sprintf("api.%s.delete_key", apiID)) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + res := testutil.CallRoute[handler.Request, openapi.ForbiddenErrorResponse](h, route, headers, req) + require.Equal(t, 403, res.Status) + require.NotNil(t, res.Body) + }) } diff --git a/go/apps/api/routes/v2_keys_get_key/404_test.go b/go/apps/api/routes/v2_keys_get_key/404_test.go index 9a94d54bb5..9aa3c53c94 100644 --- a/go/apps/api/routes/v2_keys_get_key/404_test.go +++ b/go/apps/api/routes/v2_keys_get_key/404_test.go @@ -10,6 +10,7 @@ import ( handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_get_key" "github.com/unkeyed/unkey/go/pkg/ptr" "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/testutil/seed" "github.com/unkeyed/unkey/go/pkg/uid" ) @@ -46,4 +47,31 @@ func TestGetKeyNotFound(t *testing.T) { require.Contains(t, res.Body.Error.Detail, "We could not find the requested key") }) + t.Run("key from different workspace", func(t *testing.T) { + // Create a different workspace + otherWorkspace := h.CreateWorkspace() + + // Create API and key in the other workspace + apiName := "other-workspace-api" + otherAPI := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: otherWorkspace.ID, + Name: &apiName, + }) + + otherKey := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: otherAPI.KeyAuthID.String, + WorkspaceID: otherWorkspace.ID, + }) + + // Try to access the key from different workspace using our root key + req := handler.Request{ + KeyId: otherKey.KeyID, + Decrypt: ptr.P(false), + } + + res := testutil.CallRoute[handler.Request, openapi.NotFoundErrorResponse](h, route, headers, req) + require.Equal(t, 404, res.Status) + require.NotNil(t, res.Body) + require.Contains(t, res.Body.Error.Detail, "specified key was not found") + }) } diff --git a/go/apps/api/routes/v2_keys_get_key/412_test.go b/go/apps/api/routes/v2_keys_get_key/412_test.go index 7341af7ce5..8ead515e59 100644 --- a/go/apps/api/routes/v2_keys_get_key/412_test.go +++ b/go/apps/api/routes/v2_keys_get_key/412_test.go @@ -64,4 +64,96 @@ func TestPreconditionError(t *testing.T) { require.NotNil(t, res.Body) require.NotNil(t, res.Body.Error) }) + + t.Run("api not set up for key encryption", func(t *testing.T) { + h := testutil.NewHarness(t) + + apiName := "test-api" + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: &apiName, + }) + + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_key", "api.*.decrypt_key") + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: h.Resources().UserWorkspace.ID, + }) + + route := &handler.Handler{ + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, + } + h.Register(route) + + req := handler.Request{ + Decrypt: ptr.P(true), + KeyId: key.KeyID, + } + + res := testutil.CallRoute[handler.Request, openapi.PreconditionFailedErrorResponse]( + h, + route, + headers, + req, + ) + require.Equal(t, 412, res.Status) + require.NotNil(t, res.Body) + require.NotNil(t, res.Body.Error) + require.Contains(t, res.Body.Error.Detail, "does not support key encryption") + }) + + t.Run("vault missing when decrypt requested", func(t *testing.T) { + h := testutil.NewHarness(t) + + // Create API using testutil helper + apiName := "test-api" + api := h.CreateApi(seed.CreateApiRequest{ + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: &apiName, + }) + + // Create a root key with appropriate permissions + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_key", "api.*.decrypt_key") + + // Set up request headers + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + key := h.CreateKey(seed.CreateKeyRequest{ + KeyAuthID: api.KeyAuthID.String, + WorkspaceID: h.Resources().UserWorkspace.ID, + }) + + // Create route with nil vault + routeNoVault := &handler.Handler{ + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: nil, // No vault + } + h.Register(routeNoVault) + + req := handler.Request{ + KeyId: key.KeyID, + Decrypt: ptr.P(true), + } + + res := testutil.CallRoute[handler.Request, openapi.PreconditionFailedErrorResponse](h, routeNoVault, headers, req) + require.Equal(t, 412, res.Status) + require.NotNil(t, res.Body) + require.Contains(t, res.Body.Error.Detail, "Vault hasn't been set up") + }) } diff --git a/go/apps/api/routes/v2_keys_get_key/500_test.go b/go/apps/api/routes/v2_keys_get_key/500_test.go new file mode 100644 index 0000000000..f65b3d1fd3 --- /dev/null +++ b/go/apps/api/routes/v2_keys_get_key/500_test.go @@ -0,0 +1,46 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_get_key" + "github.com/unkeyed/unkey/go/pkg/ptr" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/uid" +) + +func TestInternalError(t *testing.T) { + h := testutil.NewHarness(t) + route := &handler.Handler{ + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, + } + h.Register(route) + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_key") + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("database connection closed during request", func(t *testing.T) { + // Close the database connections to simulate a database failure + err := h.DB.Close() + require.NoError(t, err) + req := handler.Request{ + KeyId: uid.New(uid.KeyPrefix), + Decrypt: ptr.P(false), + } + + res := testutil.CallRoute[handler.Request, openapi.InternalServerErrorResponse](h, route, headers, req) + require.Equal(t, 500, res.Status) + require.NotNil(t, res.Body) + require.Contains(t, res.Body.Error.Detail, "We could not load the requested key") + }) +} diff --git a/go/apps/api/routes/v2_keys_get_key/handler.go b/go/apps/api/routes/v2_keys_get_key/handler.go index e03af1c257..8284a11c8d 100644 --- a/go/apps/api/routes/v2_keys_get_key/handler.go +++ b/go/apps/api/routes/v2_keys_get_key/handler.go @@ -2,9 +2,9 @@ package handler import ( "context" - "database/sql" "encoding/json" "net/http" + "sort" "github.com/oapi-codegen/nullable" "github.com/unkeyed/unkey/go/apps/api/openapi" @@ -28,7 +28,6 @@ type ( // Handler implements zen.Route interface for the v2 keys.getKey endpoint type Handler struct { - // Services as public fields Logger logging.Logger DB db.Database Keys keys.KeyService @@ -36,12 +35,10 @@ type Handler struct { Vault *vault.Service } -// Method returns the HTTP method this route responds to func (h *Handler) Method() string { return "POST" } -// Path returns the URL path pattern this route matches func (h *Handler) Path() string { return "/v2/keys.getKey" } @@ -80,8 +77,10 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } + keyData := db.ToKeyData(key) + // Validate key belongs to authorized workspace - if key.WorkspaceID != auth.AuthorizedWorkspaceID { + if keyData.Key.WorkspaceID != auth.AuthorizedWorkspaceID { return fault.New("key not found", fault.Code(codes.Data.Key.NotFound.URN()), fault.Internal("key belongs to different workspace"), @@ -98,7 +97,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { }), rbac.T(rbac.Tuple{ ResourceType: rbac.Api, - ResourceID: key.Api.ID, + ResourceID: keyData.Api.ID, Action: rbac.ReadKey, }), ))) @@ -106,253 +105,203 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { return err } - keyAuth, err := db.Query.FindKeyringByID(ctx, h.DB.RO(), key.KeyAuthID) - if err != nil { - if db.IsNotFound(err) { - return fault.New("api not set up for keys", - fault.Code(codes.App.Precondition.PreconditionFailed.URN()), - fault.Internal("api not set up for keys, keyauth not found"), fault.Public("The requested API is not set up to handle keys."), - ) - } - - return fault.Wrap(err, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("database error"), fault.Public("Failed to retrieve API information."), - ) - } - - decrypt := ptr.SafeDeref(req.Decrypt, false) + // Handle decryption if requested var plaintext *string + decrypt := ptr.SafeDeref(req.Decrypt, false) if decrypt { - if h.Vault == nil { - return fault.New("vault missing", - fault.Code(codes.App.Precondition.PreconditionFailed.URN()), - fault.Public("Vault hasn't been set up."), - ) - } - - // Permission check - err = auth.VerifyRootKey(ctx, keys.WithPermissions(rbac.Or( - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: "*", - Action: rbac.DecryptKey, - }), - rbac.T(rbac.Tuple{ - ResourceType: rbac.Api, - ResourceID: key.Api.ID, - Action: rbac.DecryptKey, - }), - ))) + plaintext, err = h.decryptKey(ctx, auth, keyData) if err != nil { return err } - - if !keyAuth.StoreEncryptedKeys { - return fault.New("api not set up for key encryption", - fault.Code(codes.App.Precondition.PreconditionFailed.URN()), - fault.Internal("api not set up for key encryption"), fault.Public("The API for this key does not support key encryption."), - ) - } - - // If the key is encrypted and the encryption key ID is valid, decrypt the key. - // Otherwise the key was never encrypted to begin with. - if key.EncryptedKey.Valid && key.EncryptionKeyID.Valid { - decrypted, decryptErr := h.Vault.Decrypt(ctx, &vaultv1.DecryptRequest{ - Keyring: key.WorkspaceID, - Encrypted: key.EncryptedKey.String, - }) - - if decryptErr != nil { - h.Logger.Error("failed to decrypt key", - "keyId", key.ID, - "error", decryptErr, - ) - } else { - plaintext = ptr.P(decrypted.GetPlaintext()) - } - } - } - k := openapi.KeyResponseData{ - CreatedAt: key.CreatedAtM, - Enabled: key.Enabled, - KeyId: key.ID, - Start: key.Start, - Plaintext: plaintext, - Name: nil, - Meta: nil, - Identity: nil, - Credits: nil, - Expires: nil, - Permissions: nil, - Ratelimits: nil, - Roles: nil, - UpdatedAt: nil, + response := openapi.KeyResponseData{ + CreatedAt: keyData.Key.CreatedAtM, + Enabled: keyData.Key.Enabled, + KeyId: keyData.Key.ID, + Start: keyData.Key.Start, + Plaintext: plaintext, } - if key.Name.Valid { - k.Name = ptr.P(key.Name.String) + // Set optional fields + if keyData.Key.Name.Valid { + response.Name = ptr.P(keyData.Key.Name.String) } - if key.UpdatedAtM.Valid { - k.UpdatedAt = ptr.P(key.UpdatedAtM.Int64) + if keyData.Key.UpdatedAtM.Valid { + response.UpdatedAt = ptr.P(keyData.Key.UpdatedAtM.Int64) } - if key.Expires.Valid { - k.Expires = ptr.P(key.Expires.Time.UnixMilli()) + if keyData.Key.Expires.Valid { + response.Expires = ptr.P(keyData.Key.Expires.Time.UnixMilli()) } - if key.RemainingRequests.Valid { - k.Credits = &openapi.KeyCreditsData{ - Remaining: nullable.NewNullableWithValue(int64(key.RemainingRequests.Int32)), - Refill: nil, + // Set credits + if keyData.Key.RemainingRequests.Valid { + response.Credits = &openapi.KeyCreditsData{ + Remaining: nullable.NewNullableWithValue(int64(keyData.Key.RemainingRequests.Int32)), } - if key.RefillAmount.Valid { + if keyData.Key.RefillAmount.Valid { var refillDay *int interval := openapi.Daily - if key.RefillDay.Valid { + if keyData.Key.RefillDay.Valid { interval = openapi.Monthly - refillDay = ptr.P(int(key.RefillDay.Int16)) + refillDay = ptr.P(int(keyData.Key.RefillDay.Int16)) } - k.Credits.Refill = &openapi.KeyCreditsRefill{ - Amount: int64(key.RefillAmount.Int32), + response.Credits.Refill = &openapi.KeyCreditsRefill{ + Amount: int64(keyData.Key.RefillAmount.Int32), Interval: interval, RefillDay: refillDay, } } } - if key.IdentityID.Valid { - identity, idErr := db.Query.FindIdentity(ctx, h.DB.RO(), db.FindIdentityParams{ - Identity: key.IdentityID.String, - Deleted: false, - WorkspaceID: auth.AuthorizedWorkspaceID, - }) - if idErr != nil { - if db.IsNotFound(idErr) { - return fault.New("identity not found for key", - fault.Code(codes.Data.Identity.NotFound.URN()), - fault.Internal("identity not found"), - fault.Public("The requested identity does not exist or has been deleted."), - ) + // Set identity + if keyData.Identity != nil { + response.Identity = &openapi.Identity{ + Id: keyData.Identity.ID, + ExternalId: keyData.Identity.ExternalID, + } + + if len(keyData.Identity.Meta) > 0 { + var identityMeta map[string]any + if err := json.Unmarshal(keyData.Identity.Meta, &identityMeta); err != nil { + h.Logger.Error("failed to unmarshal identity meta", "error", err) + } else { + response.Identity.Meta = &identityMeta } + } + } - return fault.Wrap(idErr, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("database error"), - fault.Public("Failed to retrieve Identity information."), - ) + // Set permissions, combine direct + role permissions + permissionSlugs := make(map[string]struct{}) + for _, p := range keyData.Permissions { + permissionSlugs[p.Slug] = struct{}{} + } + for _, p := range keyData.RolePermissions { + permissionSlugs[p.Slug] = struct{}{} + } + if len(permissionSlugs) > 0 { + slugs := make([]string, 0, len(permissionSlugs)) + for slug := range permissionSlugs { + slugs = append(slugs, slug) } + sort.Strings(slugs) + response.Permissions = &slugs + } - k.Identity = &openapi.Identity{ - Id: identity.ID, - ExternalId: identity.ExternalID, - Meta: nil, - Ratelimits: nil, + // Set roles + if len(keyData.Roles) > 0 { + roleNames := make([]string, len(keyData.Roles)) + for i, role := range keyData.Roles { + roleNames[i] = role.Name } + response.Roles = &roleNames + } - if len(identity.Meta) > 0 { - err = json.Unmarshal(identity.Meta, &k.Identity.Meta) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to unmarshal identity meta"), - fault.Public("We encountered an error while trying to unmarshal the identity meta data."), - ) + // Set ratelimits + if len(keyData.Ratelimits) > 0 { + var keyRatelimits []openapi.RatelimitResponse + var identityRatelimits []openapi.RatelimitResponse + + for _, rl := range keyData.Ratelimits { + ratelimitResp := openapi.RatelimitResponse{ + Id: rl.ID, + Duration: rl.Duration, + Limit: int64(rl.Limit), + Name: rl.Name, + AutoApply: rl.AutoApply, } - } - ratelimits, rlErr := db.Query.ListIdentityRatelimitsByID(ctx, h.DB.RO(), sql.NullString{Valid: true, String: identity.ID}) - if rlErr != nil && !db.IsNotFound(rlErr) { - return fault.Wrap(rlErr, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to retrieve identity ratelimits"), - fault.Public("We encountered an error while trying to retrieve the identity ratelimits."), - ) + // Add to key ratelimits if it belongs to this key + if rl.KeyID.Valid { + keyRatelimits = append(keyRatelimits, ratelimitResp) + } + + // Add to identity ratelimits if it has an identity_id that matches + if rl.IdentityID.Valid { + identityRatelimits = append(identityRatelimits, ratelimitResp) + } } - identityRatelimits := make([]openapi.RatelimitResponse, 0) - for _, ratelimit := range ratelimits { - identityRatelimits = append(identityRatelimits, openapi.RatelimitResponse{ - Id: ratelimit.ID, - Duration: ratelimit.Duration, - Limit: int64(ratelimit.Limit), - Name: ratelimit.Name, - AutoApply: ratelimit.AutoApply, - }) + if len(keyRatelimits) > 0 { + response.Ratelimits = &keyRatelimits } if len(identityRatelimits) > 0 { - k.Identity.Ratelimits = ptr.P(identityRatelimits) + response.Identity.Ratelimits = &identityRatelimits } } - ratelimits, err := db.Query.ListRatelimitsByKeyID(ctx, h.DB.RO(), sql.NullString{String: key.ID, Valid: true}) - if err != nil && !db.IsNotFound(err) { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to retrieve key ratelimits"), - fault.Public("We encountered an error while trying to retrieve the key ratelimits."), - ) - } - - ratelimitsResponse := make([]openapi.RatelimitResponse, len(ratelimits)) - for idx, ratelimit := range ratelimits { - ratelimitsResponse[idx] = openapi.RatelimitResponse{ - Id: ratelimit.ID, - Duration: ratelimit.Duration, - Limit: int64(ratelimit.Limit), - Name: ratelimit.Name, - AutoApply: ratelimit.AutoApply, + // Set meta + if keyData.Key.Meta.Valid { + var meta map[string]any + if err := json.Unmarshal([]byte(keyData.Key.Meta.String), &meta); err != nil { + h.Logger.Error("failed to unmarshal key meta", "error", err) + } else { + response.Meta = &meta } } - if len(ratelimitsResponse) > 0 { - k.Ratelimits = ptr.P(ratelimitsResponse) - } + return s.JSON(http.StatusOK, Response{ + Meta: openapi.Meta{ + RequestId: s.RequestID(), + }, + Data: response, + }) +} - if key.Meta.Valid { - err = json.Unmarshal([]byte(key.Meta.String), &k.Meta) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to unmarshal key meta"), - fault.Public("We encountered an error while trying to unmarshal the key meta data."), - ) - } +func (h *Handler) decryptKey(ctx context.Context, auth *keys.KeyVerifier, keyData *db.KeyData) (*string, error) { + if h.Vault == nil { + return nil, fault.New("vault missing", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Public("Vault hasn't been set up."), + ) } - permissionSlugs, err := db.Query.ListPermissionsByKeyID(ctx, h.DB.RO(), db.ListPermissionsByKeyIDParams{ - KeyID: k.KeyId, - }) + // Permission check for decryption + err := auth.VerifyRootKey(ctx, keys.WithPermissions(rbac.Or( + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: "*", + Action: rbac.DecryptKey, + }), + rbac.T(rbac.Tuple{ + ResourceType: rbac.Api, + ResourceID: keyData.Api.ID, + Action: rbac.DecryptKey, + }), + ))) if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to find permissions for key"), fault.Public("Could not load permissions for key.")) - } - if len(permissionSlugs) > 0 { - k.Permissions = ptr.P(permissionSlugs) + return nil, err } - // Get roles for the key - roles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), k.KeyId) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to find roles for key"), fault.Public("Could not load roles for key.")) + if !keyData.KeyAuth.StoreEncryptedKeys { + return nil, fault.New("api not set up for key encryption", + fault.Code(codes.App.Precondition.PreconditionFailed.URN()), + fault.Internal("api not set up for key encryption"), + fault.Public("The API for this key does not support key encryption."), + ) } - if len(roles) > 0 { - roleNames := make([]string, len(roles)) - for i, role := range roles { - roleNames[i] = role.Name - } - - k.Roles = ptr.P(roleNames) + // Only decrypt if the key is actually encrypted + if !keyData.EncryptedKey.Valid || !keyData.EncryptionKeyID.Valid { + return nil, nil } - return s.JSON(http.StatusOK, Response{ - Meta: openapi.Meta{ - RequestId: s.RequestID(), - }, - Data: k, + decrypted, err := h.Vault.Decrypt(ctx, &vaultv1.DecryptRequest{ + Keyring: keyData.Key.WorkspaceID, + Encrypted: keyData.EncryptedKey.String, }) + if err != nil { + h.Logger.Error("failed to decrypt key", + "keyId", keyData.Key.ID, + "error", err, + ) + return nil, nil // Return nil instead of failing the entire request + } + + return ptr.P(decrypted.GetPlaintext()), nil } diff --git a/go/apps/api/routes/v2_keys_whoami/500_test.go b/go/apps/api/routes/v2_keys_whoami/500_test.go new file mode 100644 index 0000000000..c23e63effd --- /dev/null +++ b/go/apps/api/routes/v2_keys_whoami/500_test.go @@ -0,0 +1,45 @@ +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/apps/api/openapi" + handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_keys_whoami" + "github.com/unkeyed/unkey/go/pkg/testutil" +) + +func TestInternalError(t *testing.T) { + h := testutil.NewHarness(t) + route := &handler.Handler{ + Logger: h.Logger, + DB: h.DB, + Keys: h.Keys, + Auditlogs: h.Auditlogs, + Vault: h.Vault, + } + h.Register(route) + + rootKey := h.CreateRootKey(h.Resources().UserWorkspace.ID, "api.*.read_key") + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + t.Run("database connection closed during request", func(t *testing.T) { + // Close the database connections to simulate a database failure + err := h.DB.Close() + require.NoError(t, err) + + req := handler.Request{ + Key: "test_some_raw_key_string", + } + + res := testutil.CallRoute[handler.Request, openapi.InternalServerErrorResponse](h, route, headers, req) + require.Equal(t, 500, res.Status) + require.NotNil(t, res.Body) + require.Contains(t, res.Body.Error.Detail, "We could not load the requested key") + }) +} diff --git a/go/apps/api/routes/v2_keys_whoami/handler.go b/go/apps/api/routes/v2_keys_whoami/handler.go index 1ec9beacf0..8f32f9f241 100644 --- a/go/apps/api/routes/v2_keys_whoami/handler.go +++ b/go/apps/api/routes/v2_keys_whoami/handler.go @@ -2,9 +2,9 @@ package handler import ( "context" - "database/sql" "encoding/json" "net/http" + "sort" "github.com/oapi-codegen/nullable" "github.com/unkeyed/unkey/go/apps/api/openapi" @@ -26,8 +26,8 @@ type ( Response = openapi.V2KeysWhoamiResponseBody ) +// Handler implements zen.Route interface for the v2 keys.whoami endpoint type Handler struct { - // Services as public fields Logger logging.Logger DB db.Database Keys keys.KeyService @@ -35,12 +35,10 @@ type Handler struct { Vault *vault.Service } -// Method returns the HTTP method this route responds to func (h *Handler) Method() string { return "POST" } -// Path returns the URL path pattern this route matches func (h *Handler) Path() string { return "/v2/keys.whoami" } @@ -76,8 +74,10 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } + keyData := db.ToKeyData(key) + // Validate key belongs to authorized workspace - if key.WorkspaceID != auth.AuthorizedWorkspaceID { + if keyData.Key.WorkspaceID != auth.AuthorizedWorkspaceID { return fault.New("key not found", fault.Code(codes.Data.Key.NotFound.URN()), fault.Internal("key belongs to different workspace"), @@ -94,7 +94,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { }), rbac.T(rbac.Tuple{ ResourceType: rbac.Api, - ResourceID: key.Api.ID, + ResourceID: keyData.Api.ID, Action: rbac.ReadKey, }), ))) @@ -106,181 +106,138 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { ) } - k := openapi.KeyResponseData{ - CreatedAt: key.CreatedAtM, - Enabled: key.Enabled, - KeyId: key.ID, - Start: key.Start, - Plaintext: nil, - Name: nil, - Meta: nil, - Identity: nil, - Credits: nil, - Expires: nil, - Permissions: nil, - Ratelimits: nil, - Roles: nil, - UpdatedAt: nil, + response := openapi.KeyResponseData{ + CreatedAt: keyData.Key.CreatedAtM, + Enabled: keyData.Key.Enabled, + KeyId: keyData.Key.ID, + Start: keyData.Key.Start, + Plaintext: nil, } - if key.Name.Valid { - k.Name = ptr.P(key.Name.String) + // Set optional fields + if keyData.Key.Name.Valid { + response.Name = ptr.P(keyData.Key.Name.String) } - if key.UpdatedAtM.Valid { - k.UpdatedAt = ptr.P(key.UpdatedAtM.Int64) + if keyData.Key.UpdatedAtM.Valid { + response.UpdatedAt = ptr.P(keyData.Key.UpdatedAtM.Int64) } - if key.Expires.Valid { - k.Expires = ptr.P(key.Expires.Time.UnixMilli()) + if keyData.Key.Expires.Valid { + response.Expires = ptr.P(keyData.Key.Expires.Time.UnixMilli()) } - if key.RemainingRequests.Valid { - k.Credits = &openapi.KeyCreditsData{ - Remaining: nullable.NewNullableWithValue(int64(key.RemainingRequests.Int32)), - Refill: nil, + // Set credits + if keyData.Key.RemainingRequests.Valid { + response.Credits = &openapi.KeyCreditsData{ + Remaining: nullable.NewNullableWithValue(int64(keyData.Key.RemainingRequests.Int32)), } - if key.RefillAmount.Valid { + if keyData.Key.RefillAmount.Valid { var refillDay *int interval := openapi.Daily - if key.RefillDay.Valid { + if keyData.Key.RefillDay.Valid { interval = openapi.Monthly - refillDay = ptr.P(int(key.RefillDay.Int16)) + refillDay = ptr.P(int(keyData.Key.RefillDay.Int16)) } - k.Credits.Refill = &openapi.KeyCreditsRefill{ - Amount: int64(key.RefillAmount.Int32), + response.Credits.Refill = &openapi.KeyCreditsRefill{ + Amount: int64(keyData.Key.RefillAmount.Int32), Interval: interval, RefillDay: refillDay, } } } - if key.IdentityID.Valid { - identity, idErr := db.Query.FindIdentity(ctx, h.DB.RO(), db.FindIdentityParams{Identity: key.IdentityID.String, WorkspaceID: auth.AuthorizedWorkspaceID, Deleted: false}) - if idErr != nil { - if db.IsNotFound(idErr) { - return fault.New("identity not found for key", - fault.Code(codes.Data.Identity.NotFound.URN()), - fault.Internal("identity not found"), - fault.Public("The requested identity does not exist or has been deleted."), - ) - } - - return fault.Wrap(idErr, - fault.Code(codes.App.Internal.ServiceUnavailable.URN()), - fault.Internal("database error"), - fault.Public("Failed to retrieve Identity information."), - ) - } - - k.Identity = &openapi.Identity{ - Id: identity.ID, - ExternalId: identity.ExternalID, - Meta: nil, - Ratelimits: nil, + // Set identity + if keyData.Identity != nil { + response.Identity = &openapi.Identity{ + Id: keyData.Identity.ID, + ExternalId: keyData.Identity.ExternalID, } - if len(identity.Meta) > 0 { - err = json.Unmarshal(identity.Meta, &k.Identity.Meta) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to unmarshal identity meta"), - fault.Public("We encountered an error while trying to unmarshal the identity meta data."), - ) + if len(keyData.Identity.Meta) > 0 { + var identityMeta map[string]any + if err := json.Unmarshal(keyData.Identity.Meta, &identityMeta); err != nil { + h.Logger.Error("failed to unmarshal identity meta", "error", err) + } else { + response.Identity.Meta = &identityMeta } } - - ratelimits, rlErr := db.Query.ListIdentityRatelimitsByID(ctx, h.DB.RO(), sql.NullString{Valid: true, String: identity.ID}) - if rlErr != nil && !db.IsNotFound(rlErr) { - return fault.Wrap(rlErr, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to retrieve identity ratelimits"), - fault.Public("We encountered an error while trying to retrieve the identity ratelimits."), - ) - } - - identityRatelimits := make([]openapi.RatelimitResponse, 0, len(ratelimits)) - for _, ratelimit := range ratelimits { - identityRatelimits = append(identityRatelimits, openapi.RatelimitResponse{ - Id: ratelimit.ID, - Duration: ratelimit.Duration, - Limit: int64(ratelimit.Limit), - Name: ratelimit.Name, - AutoApply: ratelimit.AutoApply, - }) - } - - if len(identityRatelimits) > 0 { - k.Identity.Ratelimits = ptr.P(identityRatelimits) - } } - ratelimits, err := db.Query.ListRatelimitsByKeyID(ctx, h.DB.RO(), sql.NullString{String: key.ID, Valid: true}) - if err != nil && !db.IsNotFound(err) { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to retrieve key ratelimits"), - fault.Public("We encountered an error while trying to retrieve the key ratelimits."), - ) + // Set permissions, combine direct + role permissions + permissionSlugs := make(map[string]struct{}) + for _, p := range keyData.Permissions { + permissionSlugs[p.Slug] = struct{}{} } - - ratelimitsResponse := make([]openapi.RatelimitResponse, len(ratelimits)) - for idx, ratelimit := range ratelimits { - ratelimitsResponse[idx] = openapi.RatelimitResponse{ - Id: ratelimit.ID, - Duration: ratelimit.Duration, - Limit: int64(ratelimit.Limit), - Name: ratelimit.Name, - AutoApply: ratelimit.AutoApply, - } + for _, p := range keyData.RolePermissions { + permissionSlugs[p.Slug] = struct{}{} } - - if len(ratelimitsResponse) > 0 { - k.Ratelimits = ptr.P(ratelimitsResponse) + if len(permissionSlugs) > 0 { + slugs := make([]string, 0, len(permissionSlugs)) + for slug := range permissionSlugs { + slugs = append(slugs, slug) + } + sort.Strings(slugs) + response.Permissions = &slugs } - if key.Meta.Valid { - err = json.Unmarshal([]byte(key.Meta.String), &k.Meta) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to unmarshal key meta"), - fault.Public("We encountered an error while trying to unmarshal the key meta data."), - ) + // Set roles + if len(keyData.Roles) > 0 { + roleNames := make([]string, len(keyData.Roles)) + for i, role := range keyData.Roles { + roleNames[i] = role.Name } + response.Roles = &roleNames } - permissionSlugs, err := db.Query.ListPermissionsByKeyID(ctx, h.DB.RO(), db.ListPermissionsByKeyIDParams{ - KeyID: k.KeyId, - }) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to find permissions for key"), fault.Public("Could not load permissions for key.")) - } + // Set ratelimits + if len(keyData.Ratelimits) > 0 { + var keyRatelimits []openapi.RatelimitResponse + var identityRatelimits []openapi.RatelimitResponse - if len(permissionSlugs) > 0 { - k.Permissions = ptr.P(permissionSlugs) - } + for _, rl := range keyData.Ratelimits { + ratelimitResp := openapi.RatelimitResponse{ + Id: rl.ID, + Duration: rl.Duration, + Limit: int64(rl.Limit), + Name: rl.Name, + AutoApply: rl.AutoApply, + } - // Get roles for the key - roles, err := db.Query.ListRolesByKeyID(ctx, h.DB.RO(), k.KeyId) - if err != nil { - return fault.Wrap(err, fault.Code(codes.App.Internal.UnexpectedError.URN()), - fault.Internal("unable to find roles for key"), fault.Public("Could not load roles for key.")) - } + // Add to key ratelimits if it belongs to this key + if rl.KeyID.Valid { + keyRatelimits = append(keyRatelimits, ratelimitResp) + } + // Add to identity ratelimits if it has an identity_id that matches + if rl.IdentityID.Valid { + identityRatelimits = append(identityRatelimits, ratelimitResp) + } + } - if len(roles) > 0 { - roleNames := make([]string, len(roles)) - for i, role := range roles { - roleNames[i] = role.Name + if len(keyRatelimits) > 0 { + response.Ratelimits = &keyRatelimits } + if len(identityRatelimits) > 0 { + response.Identity.Ratelimits = &identityRatelimits + } + } - k.Roles = ptr.P(roleNames) + // Set meta + if keyData.Key.Meta.Valid { + var meta map[string]any + if err := json.Unmarshal([]byte(keyData.Key.Meta.String), &meta); err != nil { + h.Logger.Error("failed to unmarshal key meta", "error", err) + } else { + response.Meta = &meta + } } return s.JSON(http.StatusOK, Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), }, - Data: k, + Data: response, }) } diff --git a/go/pkg/db/key_data_test.go b/go/pkg/db/key_data_test.go new file mode 100644 index 0000000000..ea2146d7a6 --- /dev/null +++ b/go/pkg/db/key_data_test.go @@ -0,0 +1,255 @@ +package db + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToKeyData_ValidCases(t *testing.T) { + t.Run("FindLiveKeyByIDRow value", func(t *testing.T) { + row := FindLiveKeyByIDRow{ + ID: "test-key-id", + Hash: "test-hash", + WorkspaceID: "test-workspace", + Enabled: true, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "test-key-id", result.Key.ID) + require.Equal(t, "test-hash", result.Key.Hash) + require.Equal(t, "test-workspace", result.Key.WorkspaceID) + require.True(t, result.Key.Enabled) + }) + + t.Run("FindLiveKeyByIDRow pointer", func(t *testing.T) { + row := FindLiveKeyByIDRow{ + ID: "test-key-id-ptr", + Hash: "test-hash-ptr", + WorkspaceID: "test-workspace-ptr", + Enabled: false, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "test-key-id-ptr", result.Key.ID) + require.Equal(t, "test-hash-ptr", result.Key.Hash) + require.Equal(t, "test-workspace-ptr", result.Key.WorkspaceID) + require.False(t, result.Key.Enabled) + }) + + t.Run("FindLiveKeyByHashRow value", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "hash-key-id", + Hash: "hash-test", + WorkspaceID: "hash-workspace", + Enabled: true, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "hash-key-id", result.Key.ID) + require.Equal(t, "hash-test", result.Key.Hash) + require.Equal(t, "hash-workspace", result.Key.WorkspaceID) + require.True(t, result.Key.Enabled) + }) + + t.Run("FindLiveKeyByHashRow pointer", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "hash-key-ptr", + Hash: "hash-ptr", + WorkspaceID: "hash-workspace-ptr", + Enabled: false, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "hash-key-ptr", result.Key.ID) + require.Equal(t, "hash-ptr", result.Key.Hash) + require.Equal(t, "hash-workspace-ptr", result.Key.WorkspaceID) + require.False(t, result.Key.Enabled) + }) +} + +func TestToKeyData_EmptyValues(t *testing.T) { + t.Run("zero value FindLiveKeyByIDRow", func(t *testing.T) { + row := FindLiveKeyByIDRow{} // All zero values + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "", result.Key.ID) + require.Equal(t, "", result.Key.Hash) + require.Equal(t, "", result.Key.WorkspaceID) + require.False(t, result.Key.Enabled) // bool zero value + require.Nil(t, result.Identity) // No identity data + require.Empty(t, result.Roles) + require.Empty(t, result.Permissions) + require.Empty(t, result.RolePermissions) + require.Empty(t, result.Ratelimits) + }) + + t.Run("zero value FindLiveKeyByHashRow", func(t *testing.T) { + row := FindLiveKeyByHashRow{} // All zero values + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Equal(t, "", result.Key.ID) + require.Equal(t, "", result.Key.Hash) + require.Equal(t, "", result.Key.WorkspaceID) + require.False(t, result.Key.Enabled) + require.Nil(t, result.Identity) + require.Empty(t, result.Roles) + require.Empty(t, result.Permissions) + require.Empty(t, result.RolePermissions) + require.Empty(t, result.Ratelimits) + }) +} + +func TestToKeyData_WithIdentity(t *testing.T) { + t.Run("with valid identity data", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "key-with-identity", + WorkspaceID: "workspace-123", + IdentityTableID: sql.NullString{String: "identity-123", Valid: true}, + IdentityExternalID: sql.NullString{String: "user-456", Valid: true}, + IdentityMeta: []byte(`{"role": "admin"}`), + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.NotNil(t, result.Identity) + require.Equal(t, "identity-123", result.Identity.ID) + require.Equal(t, "user-456", result.Identity.ExternalID) + require.Equal(t, "workspace-123", result.Identity.WorkspaceID) + require.Equal(t, []byte(`{"role": "admin"}`), result.Identity.Meta) + }) + + t.Run("without identity data", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "key-no-identity", + WorkspaceID: "workspace-123", + IdentityTableID: sql.NullString{Valid: false}, // No identity + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Nil(t, result.Identity) + }) +} + +func TestToKeyData_WithJSONFields(t *testing.T) { + t.Run("with valid JSON arrays", func(t *testing.T) { + roles := []RoleInfo{{Name: "admin"}, {Name: "user"}} + rolesJSON, _ := json.Marshal(roles) + + permissions := []PermissionInfo{{Slug: "read"}, {Slug: "write"}} + permissionsJSON, _ := json.Marshal(permissions) + + ratelimits := []RatelimitInfo{ + { + ID: "rate-1", + Duration: 3600, + Limit: 100, + Name: "hourly-limit", + AutoApply: true, + }, + { + ID: "rate-2", + Duration: 60, + Limit: 10, + Name: "minute-limit", + AutoApply: false, + }, + } + ratelimitsJSON, _ := json.Marshal(ratelimits) + + row := FindLiveKeyByHashRow{ + ID: "key-with-json", + Roles: rolesJSON, + Permissions: permissionsJSON, + RolePermissions: permissionsJSON, + Ratelimits: ratelimitsJSON, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Len(t, result.Roles, 2) + require.Equal(t, "admin", result.Roles[0].Name) + require.Equal(t, "user", result.Roles[1].Name) + require.Len(t, result.Permissions, 2) + require.Equal(t, "read", result.Permissions[0].Slug) + require.Equal(t, "write", result.Permissions[1].Slug) + require.Len(t, result.RolePermissions, 2) + require.Len(t, result.Ratelimits, 2) + require.Equal(t, "rate-1", result.Ratelimits[0].ID) + require.Equal(t, int64(3600), result.Ratelimits[0].Duration) + require.Equal(t, int32(100), result.Ratelimits[0].Limit) + require.Equal(t, "hourly-limit", result.Ratelimits[0].Name) + require.True(t, result.Ratelimits[0].AutoApply) + }) + + t.Run("with invalid JSON - should ignore errors", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "key-bad-json", + Roles: []byte(`{invalid json}`), // Bad JSON + Permissions: []byte(`not json at all`), // Bad JSON + Ratelimits: []byte(`{"incomplete": true`), // Bad JSON + } + + result := ToKeyData(row) + + require.NotNil(t, result) + // Should default to empty arrays when JSON unmarshaling fails + require.Empty(t, result.Roles) + require.Empty(t, result.Permissions) + require.Empty(t, result.RolePermissions) + require.Empty(t, result.Ratelimits) + }) + + t.Run("with nil JSON fields", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "key-nil-json", + Roles: nil, + Permissions: nil, + RolePermissions: nil, + Ratelimits: nil, + } + + result := ToKeyData(row) + + require.NotNil(t, result) + require.Empty(t, result.Roles) + require.Empty(t, result.Permissions) + require.Empty(t, result.RolePermissions) + require.Empty(t, result.Ratelimits) + }) + + t.Run("with non-byte slice fields", func(t *testing.T) { + row := FindLiveKeyByHashRow{ + ID: "key-wrong-type", + Roles: "not a byte slice", // Wrong type + Permissions: 123, // Wrong type + RolePermissions: struct{}{}, // Wrong type + } + + result := ToKeyData(row) + + require.NotNil(t, result) + // Should default to empty arrays when type assertion fails + require.Empty(t, result.Roles) + require.Empty(t, result.Permissions) + require.Empty(t, result.RolePermissions) + }) +} diff --git a/go/pkg/zen/auth.go b/go/pkg/zen/auth.go index b44a45b30d..e6ef542e20 100644 --- a/go/pkg/zen/auth.go +++ b/go/pkg/zen/auth.go @@ -49,5 +49,4 @@ func Bearer(s *Session) (string, error) { } return bearer, nil - }