Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go/apps/api/routes/v2_keys_add_roles/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

rolePermissions := make([]db.Permission, 0)
json.Unmarshal(role.Permissions.([]byte), &rolePermissions)
if permBytes, ok := role.Permissions.([]byte); ok && permBytes != nil {
_ = json.Unmarshal(permBytes, &rolePermissions) // Ignore error, default to empty array
}

perms := make([]openapi.Permission, 0)
for _, permission := range rolePermissions {
Expand Down
4 changes: 3 additions & 1 deletion go/apps/api/routes/v2_keys_remove_roles/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

rolePermissions := make([]db.Permission, 0)
json.Unmarshal(role.Permissions.([]byte), &rolePermissions)
if permBytes, ok := role.Permissions.([]byte); ok && permBytes != nil {
_ = json.Unmarshal(permBytes, &rolePermissions) // Ignore error, default to empty array
}

perms := make([]openapi.Permission, 0)
for _, permission := range rolePermissions {
Expand Down
11 changes: 10 additions & 1 deletion go/apps/api/routes/v2_keys_set_roles/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,16 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

rolePermissions := make([]db.Permission, 0)
json.Unmarshal(role.Permissions.([]byte), &rolePermissions)
if permBytes, ok := role.Permissions.([]byte); ok && permBytes != nil {
// AIDEV-SAFETY: On JSON parse failure, we default to empty permissions list
// to maintain least-privilege security posture rather than failing open
if err := json.Unmarshal(permBytes, &rolePermissions); err != nil {
h.Logger.Debug("failed to parse role permissions JSON, defaulting to empty list",
"roleId", role.ID,
"rawBytes", string(permBytes),
"error", err.Error())
}
}

perms := make([]openapi.Permission, 0)
for _, permission := range rolePermissions {
Expand Down
4 changes: 3 additions & 1 deletion go/apps/api/routes/v2_permissions_get_role/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

rolePermissions := make([]db.Permission, 0)
json.Unmarshal(role.Permissions.([]byte), &rolePermissions)
if permBytes, ok := role.Permissions.([]byte); ok && permBytes != nil {
_ = json.Unmarshal(permBytes, &rolePermissions) // Ignore error, default to empty array
}

perms := make([]openapi.Permission, 0)
for _, perm := range rolePermissions {
Expand Down
4 changes: 3 additions & 1 deletion go/apps/api/routes/v2_permissions_list_roles/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

rolePermissions := make([]db.Permission, 0)
json.Unmarshal(role.Permissions.([]byte), &rolePermissions)
if permBytes, ok := role.Permissions.([]byte); ok && permBytes != nil {
_ = json.Unmarshal(permBytes, &rolePermissions) // Ignore error, default to empty array
}
perms := make([]openapi.Permission, 0)

for _, perm := range rolePermissions {
Expand Down
20 changes: 17 additions & 3 deletions go/apps/api/routes/v2_ratelimit_get_override/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ type Handler struct {
RatelimitNamespaceCache cache.Cache[cache.ScopedKey, db.FindRatelimitNamespace]
}

// decodeOverrides safely decodes JSON bytes into override slice with proper error handling
func decodeOverrides(data interface{}) ([]db.FindRatelimitNamespaceLimitOverride, error) {
overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0)
if overrideBytes, ok := data.([]byte); ok && overrideBytes != nil {
if err := json.Unmarshal(overrideBytes, &overrides); err != nil {
return nil, fault.Wrap(err,
fault.Code(codes.App.Internal.UnexpectedError.URN()),
fault.Public("An unexpected error occurred while processing override data."))
}
}
return overrides, nil
}

// Method returns the HTTP method this route responds to
func (h *Handler) Method() string {
return "POST"
Expand Down Expand Up @@ -76,8 +89,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
WildcardOverrides: make([]db.FindRatelimitNamespaceLimitOverride, 0),
}

overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0)
err = json.Unmarshal(response.Overrides.([]byte), &overrides)
overrides, err := decodeOverrides(response.Overrides)
if err != nil {
return result, err
}
Expand All @@ -100,7 +112,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
)
}

return err
return fault.Wrap(err,
fault.Code(codes.App.Internal.UnexpectedError.URN()),
fault.Public("An unexpected error occurred while fetching the namespace."))
}

if hit == cache.Null {
Expand Down
8 changes: 5 additions & 3 deletions go/apps/api/routes/v2_ratelimit_limit/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}

overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0)
err = json.Unmarshal(response.Overrides.([]byte), &overrides)
if err != nil {
return result, err
if overrideBytes, ok := response.Overrides.([]byte); ok && overrideBytes != nil {
err = json.Unmarshal(overrideBytes, &overrides)
if err != nil {
return result, err
}
}

for _, override := range overrides {
Expand Down
2 changes: 1 addition & 1 deletion go/internal/services/keys/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestCreateKey_Uniqueness(t *testing.T) {
keys := make(map[string]bool)
hashes := make(map[string]bool)

for i := 0; i < 10; i++ {
for range 10 {
req := CreateKeyRequest{
Prefix: "",
ByteLength: 16,
Expand Down
46 changes: 34 additions & 12 deletions go/internal/services/keys/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ func (s *service) GetRootKey(ctx context.Context, sess *zen.Session) (*KeyVerifi
}

key, log, err := s.Get(ctx, sess, rootKey)
if err != nil {
return nil, log, err
}

if key.Key.ForWorkspaceID.Valid {
key.AuthorizedWorkspaceID = key.Key.ForWorkspaceID.String
}
sess.WorkspaceID = key.AuthorizedWorkspaceID
if err != nil {
return nil, log, err
}

if key.Status != StatusValid {
return nil, log, fault.Wrap(
Expand Down Expand Up @@ -116,17 +117,38 @@ func (s *service) Get(ctx context.Context, sess *zen.Session, rawKey string) (*K
// The DB returns this in array format and an empty array if not found
var roles, permissions []string
var ratelimitArr []db.KeyFindForVerificationRatelimit
err = json.Unmarshal(key.Roles.([]byte), &roles)
if err != nil {
return nil, emptyLog, err

// Safely handle roles field
rolesBytes, ok := key.Roles.([]byte)
if !ok || rolesBytes == nil {
roles = []string{} // Default to empty array if nil or wrong type
} else {
err = json.Unmarshal(rolesBytes, &roles)
if err != nil {
return nil, emptyLog, fault.Wrap(err, fault.Internal("failed to unmarshal roles"))
}
}
err = json.Unmarshal(key.Permissions.([]byte), &permissions)
if err != nil {
return nil, emptyLog, err

// Safely handle permissions field
permissionsBytes, ok := key.Permissions.([]byte)
if !ok || permissionsBytes == nil {
permissions = []string{} // Default to empty array if nil or wrong type
} else {
err = json.Unmarshal(permissionsBytes, &permissions)
if err != nil {
return nil, emptyLog, fault.Wrap(err, fault.Internal("failed to unmarshal permissions"))
}
}
err = json.Unmarshal(key.Ratelimits.([]byte), &ratelimitArr)
if err != nil {
return nil, emptyLog, err

// Safely handle ratelimits field
ratelimitsBytes, ok := key.Ratelimits.([]byte)
if !ok || ratelimitsBytes == nil {
ratelimitArr = []db.KeyFindForVerificationRatelimit{} // Default to empty array if nil or wrong type
} else {
err = json.Unmarshal(ratelimitsBytes, &ratelimitArr)
if err != nil {
return nil, emptyLog, fault.Wrap(err, fault.Internal("failed to unmarshal ratelimits"))
}
}

// Convert rate limits array to map (key name -> config)
Expand Down
88 changes: 88 additions & 0 deletions go/internal/services/keys/get_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package keys

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"github.com/unkeyed/unkey/go/pkg/codes"
"github.com/unkeyed/unkey/go/pkg/fault"
)

func TestGetRootKey_ErrorHandling_ReturnsError(t *testing.T) {
t.Parallel()

// Create a service with nil dependencies, similar to create_test.go pattern
s := &service{}

ctx := context.Background()

// This test validates that GetRootKey properly handles errors from Get()
// Before the fix, GetRootKey would panic when trying to access key.Key.ForWorkspaceID.Valid
// after Get() returned an error with nil key. Now it should return the error safely.
key, log, err := s.GetRootKey(ctx, nil)

require.Error(t, err)
require.Nil(t, key)
require.NotNil(t, log)

// Verify specific error code for missing auth when session is nil
code, ok := fault.GetCode(err)
require.True(t, ok)
require.Equal(t, codes.Auth.Authentication.Missing.URN(), code)
}

func TestGetRootKey_WithEmptyRawKey_ReturnsError(t *testing.T) {
t.Parallel()

// Create a service with nil dependencies, following create_test.go pattern
s := &service{}

ctx := context.Background()

// Call Get with empty raw key to test the assert.NotEmpty validation
key, log, err := s.Get(ctx, nil, "")

// Verify that we get an error for empty key
require.Error(t, err)
require.Nil(t, key)
require.NotNil(t, log)
require.Contains(t, err.Error(), "rawKey is empty")
}

func TestGet_WithEmptyRawKey_ReturnsError(t *testing.T) {
t.Parallel()

// Test the assert.NotEmpty validation path directly in Get function
s := &service{}
ctx := context.Background()

key, log, err := s.Get(ctx, nil, "")

require.Error(t, err)
require.Nil(t, key)
require.NotNil(t, log)
require.Contains(t, err.Error(), "rawKey is empty")
}

func TestGet_EmptyString_Variants(t *testing.T) {
t.Parallel()

// Test various empty string cases to improve assert.NotEmpty coverage
s := &service{}
ctx := context.Background()

// Only test cases that will hit the validation path, not the cache/db path
emptyVariants := []string{
"", // Classic empty string
}

for _, empty := range emptyVariants {
key, log, err := s.Get(ctx, nil, empty)

require.Error(t, err)
require.Nil(t, key)
require.NotNil(t, log)
require.Contains(t, err.Error(), "rawKey is empty")
}
}
9 changes: 9 additions & 0 deletions go/pkg/zen/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ import (
// }
// // Validate the token
func Bearer(s *Session) (string, error) {
if s == nil {
return "", fault.New("nil session", fault.Code(codes.Auth.Authentication.Missing.URN()),
fault.Internal("session is nil"), fault.Public("Invalid session."))
}

if s.r == nil {
return "", fault.New("nil request", fault.Code(codes.Auth.Authentication.Malformed.URN()),
fault.Internal("session request is nil"), fault.Public("Invalid request."))
}

header := s.r.Header.Get("Authorization")
if header == "" {
Expand Down
30 changes: 30 additions & 0 deletions go/pkg/zen/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,33 @@ func TestBearer_Integration(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "token123", token)
}

func TestBearer_NilSession(t *testing.T) {
t.Parallel()

// Test with nil session - should not panic
token, err := Bearer(nil)
require.Error(t, err)
require.Empty(t, token)
require.Contains(t, err.Error(), "nil session")

code, ok := fault.GetCode(err)
require.True(t, ok)
require.Equal(t, codes.Auth.Authentication.Missing.URN(), code)
}

func TestBearer_NilRequest(t *testing.T) {
t.Parallel()

// Test with session that has nil request - should not panic
sess := &Session{} // r field is nil

token, err := Bearer(sess)
require.Error(t, err)
require.Empty(t, token)
require.Contains(t, err.Error(), "nil request")

code, ok := fault.GetCode(err)
require.True(t, ok)
require.Equal(t, codes.Auth.Authentication.Malformed.URN(), code)
}