diff --git a/go.mod b/go.mod index e9214d4cd..13e9f8d67 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/BurntSushi/toml v1.5.0 github.com/coreos/go-oidc/v3 v3.14.1 github.com/fsnotify/fsnotify v1.9.0 + github.com/go-jose/go-jose/v4 v4.0.5 github.com/mark3labs/mcp-go v0.34.0 github.com/pkg/errors v0.9.1 github.com/spf13/afero v1.14.0 @@ -52,7 +53,6 @@ require ( github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-errors/errors v1.4.2 // indirect github.com/go-gorp/gorp/v3 v3.1.0 // indirect - github.com/go-jose/go-jose/v4 v4.0.5 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index c517cb72e..d46e25f64 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -2,14 +2,13 @@ package http import ( "context" - "encoding/base64" - "encoding/json" "fmt" "net/http" "strings" - "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "k8s.io/klog/v2" "github.com/manusa/kubernetes-mcp-server/pkg/mcp" @@ -55,7 +54,10 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp // Validate the token offline for simple sanity check // Because missing expected audience and expired tokens must be // rejected already. - claims, err := validateJWTToken(token, audience) + claims, err := ParseJWTClaims(token) + if err == nil && claims != nil { + err = claims.Validate(audience) + } if err != nil { klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) @@ -118,11 +120,25 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp } } +var allSignatureAlgorithms = []jose.SignatureAlgorithm{ + jose.EdDSA, + jose.HS256, + jose.HS384, + jose.HS512, + jose.RS256, + jose.RS384, + jose.RS512, + jose.ES256, + jose.ES384, + jose.ES512, + jose.PS256, + jose.PS384, + jose.PS512, +} + type JWTClaims struct { - Issuer string `json:"iss"` - Audience any `json:"aud"` - ExpiresAt int64 `json:"exp"` - Scope string `json:"scope,omitempty"` + jwt.Claims + Scope string `json:"scope,omitempty"` } func (c *JWTClaims) GetScopes() []string { @@ -132,66 +148,21 @@ func (c *JWTClaims) GetScopes() []string { return strings.Fields(c.Scope) } -func (c *JWTClaims) ContainsAudience(audience string) bool { - switch aud := c.Audience.(type) { - case string: - return aud == audience - case []interface{}: - for _, a := range aud { - if str, ok := a.(string); ok && str == audience { - return true - } - } - case []string: - for _, a := range aud { - if a == audience { - return true - } - } - } - return false -} - -// validateJWTToken validates basic JWT claims without signature verification and returns the claims -func validateJWTToken(token, audience string) (*JWTClaims, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("invalid JWT token format") - } - - claims, err := parseJWTClaims(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to parse JWT claims: %v", err) - } - - if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt { - return nil, fmt.Errorf("token expired") - } - - if !claims.ContainsAudience(audience) { - return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience) - } - - return claims, nil +// Validate Checks if the JWT claims are valid and if the audience matches the expected one. +func (c *JWTClaims) Validate(audience string) error { + return c.Claims.Validate(jwt.Expected{ + AnyAudience: jwt.Audience{audience}, + }) } -func parseJWTClaims(payload string) (*JWTClaims, error) { - // Add padding if needed - if len(payload)%4 != 0 { - payload += strings.Repeat("=", 4-len(payload)%4) - } - - decoded, err := base64.URLEncoding.DecodeString(payload) +func ParseJWTClaims(token string) (*JWTClaims, error) { + tkn, err := jwt.ParseSigned(token, allSignatureAlgorithms) if err != nil { - return nil, fmt.Errorf("failed to decode JWT payload: %v", err) + return nil, fmt.Errorf("failed to parse JWT token: %w", err) } - - var claims JWTClaims - if err := json.Unmarshal(decoded, &claims); err != nil { - return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err) - } - - return &claims, nil + claims := &JWTClaims{} + err = tkn.UnsafeClaimsWithoutVerification(claims) + return claims, err } func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error { diff --git a/pkg/http/authorization_test.go b/pkg/http/authorization_test.go index 0ffe816e0..9dd45111c 100644 --- a/pkg/http/authorization_test.go +++ b/pkg/http/authorization_test.go @@ -1,187 +1,222 @@ package http import ( - "encoding/base64" - "encoding/json" "net/http" "net/http/httptest" "strings" "testing" - "time" + + "github.com/go-jose/go-jose/v4/jwt" ) -func TestParseJWTClaims(t *testing.T) { - t.Run("valid JWT payload", func(t *testing.T) { - // Sample payload from a valid JWT - payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9" +const ( + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA // notsecret + tokenBasicNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA" // notsecret + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA // notsecret + tokenBasicExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA" // notsecret + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ // notsecret + tokenMultipleAudienceNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ" // notsecret +) - claims, err := parseJWTClaims(payload) +func TestParseJWTClaimsPayloadValid(t *testing.T) { + basicClaims, err := ParseJWTClaims(tokenBasicNotExpired) + t.Run("Is parseable", func(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - - if claims == nil { + if basicClaims == nil { t.Fatal("expected claims, got nil") } - - if claims.Issuer != "https://kubernetes.default.svc.cluster.local" { - t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer) + }) + t.Run("Parses issuer", func(t *testing.T) { + if basicClaims.Issuer != "https://kubernetes.default.svc.cluster.local" { + t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", basicClaims.Issuer) } - + }) + t.Run("Parses audience", func(t *testing.T) { expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"} for _, expected := range expectedAudiences { - if !claims.ContainsAudience(expected) { + if !basicClaims.Audience.Contains(expected) { t.Errorf("expected audience to contain %s", expected) } } - - if claims.ExpiresAt != 1751963948 { - t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt) + }) + t.Run("Parses expiration", func(t *testing.T) { + if *basicClaims.Expiry != jwt.NumericDate(253402297199) { + t.Errorf("expected expiration 253402297199, got %d", basicClaims.Expiry) } }) - - t.Run("payload needs padding", func(t *testing.T) { - // Create a payload that needs padding - testClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "test-audience", - ExpiresAt: time.Now().Add(time.Hour).Unix(), + t.Run("Parses scope", func(t *testing.T) { + scopeClaims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if scopeClaims == nil { + t.Fatal("expected claims, got nil") } - jsonBytes, _ := json.Marshal(testClaims) - // Create a payload without proper padding - encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=") + scopes := scopeClaims.GetScopes() - claims, err := parseJWTClaims(encodedWithoutPadding) + expectedScopes := []string{"read", "write"} + if len(scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(scopes)) + } + for i, expectedScope := range expectedScopes { + if scopes[i] != expectedScope { + t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) + } + } + }) + t.Run("Parses expired token", func(t *testing.T) { + expiredClaims, err := ParseJWTClaims(tokenBasicExpired) if err != nil { t.Fatalf("expected no error, got %v", err) } - if claims.Issuer != "test-issuer" { - t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer) + if *expiredClaims.Expiry != jwt.NumericDate(1) { + t.Errorf("expected expiration 1, got %d", basicClaims.Expiry) } }) +} - t.Run("invalid base64 payload", func(t *testing.T) { - invalidPayload := "invalid-base64!!!" +func TestParseJWTClaimsPayloadInvalid(t *testing.T) { + t.Run("invalid token segments", func(t *testing.T) { + invalidToken := "header.payload.signature.extra" - _, err := parseJWTClaims(invalidPayload) + _, err := ParseJWTClaims(invalidToken) if err == nil { - t.Error("expected error for invalid base64, got nil") + t.Fatal("expected error for invalid token segments, got nil") } - if !strings.Contains(err.Error(), "failed to decode JWT payload") { - t.Errorf("expected decode error message, got %v", err) + if !strings.Contains(err.Error(), "compact JWS format must have three parts") { + t.Errorf("expected invalid token segments error message, got %v", err) } }) + t.Run("invalid base64 payload", func(t *testing.T) { + invalidPayload := "invalid_base64" + tokenBasicNotExpired - t.Run("invalid JSON payload", func(t *testing.T) { - // Valid base64 but invalid JSON - invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json")) - - _, err := parseJWTClaims(invalidJSON) + _, err := ParseJWTClaims(invalidPayload) if err == nil { - t.Error("expected error for invalid JSON, got nil") + t.Fatal("expected error for invalid base64, got nil") } - if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") { - t.Errorf("expected unmarshal error message, got %v", err) + if !strings.Contains(err.Error(), "illegal base64 data") { + t.Errorf("expected decode error message, got %v", err) } }) } -func TestValidateJWTToken(t *testing.T) { - t.Run("invalid token format - not enough parts", func(t *testing.T) { - invalidToken := "header.payload" +func TestJWTTokenValidate(t *testing.T) { + t.Run("expired token returns error", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenBasicExpired) + if err != nil { + t.Fatalf("expected no error for expired token parsing, got %v", err) + } - _, err := validateJWTToken(invalidToken, "test") + err = claims.Validate("kubernetes-mcp-server") if err == nil { - t.Error("expected error for invalid token format, got nil") + t.Fatalf("expected error for expired token, got nil") } - if !strings.Contains(err.Error(), "invalid JWT token format") { - t.Errorf("expected format error message, got %v", err) + if !strings.Contains(err.Error(), "token is expired (exp)") { + t.Errorf("expected expiration error message, got %v", err) } }) - t.Run("expired token", func(t *testing.T) { - // Create an expired token - expiredClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "kubernetes-mcp-server", - ExpiresAt: time.Now().Add(-time.Hour).Unix(), + t.Run("multiple audiences with correct one", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error for multiple audience token parsing, got %v", err) + } + if claims == nil { + t.Fatalf("expected claims to be returned, got nil") + } + + err = claims.Validate("kubernetes-mcp-server") + if err != nil { + t.Fatalf("expected no error for valid audience, got %v", err) } + }) - jsonBytes, _ := json.Marshal(expiredClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - expiredToken := "header." + payload + ".signature" + t.Run("multiple audiences with mismatch returns error", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error for multiple audience token parsing, got %v", err) + } + if claims == nil { + t.Fatalf("expected claims to be returned, got nil") + } - _, err := validateJWTToken(expiredToken, "kubernetes-mcp-server") + err = claims.Validate("missing-audience") if err == nil { - t.Error("expected error for expired token, got nil") + t.Fatalf("expected error for token with wrong audience, got nil") } - if !strings.Contains(err.Error(), "token expired") { - t.Errorf("expected expiration error message, got %v", err) + if !strings.Contains(err.Error(), "invalid audience claim (aud)") { + t.Errorf("expected audience mismatch error, got %v", err) } }) +} - t.Run("multiple audiences with correct one", func(t *testing.T) { - // Create a token with multiple audiences including the correct one - multiAudClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"}, - ExpiresAt: time.Now().Add(time.Hour).Unix(), - Scope: "read write admin", +func TestJWTClaimsGetScopes(t *testing.T) { + t.Run("no scopes", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenBasicExpired) + if err != nil { + t.Fatalf("expected no error for parsing token, got %v", err) } - jsonBytes, _ := json.Marshal(multiAudClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - multiAudToken := "header." + payload + ".signature" - - claims, err := validateJWTToken(multiAudToken, "kubernetes-mcp-server") - if err != nil { - t.Errorf("expected no error for token with multiple audiences, got %v", err) + if scopes := claims.GetScopes(); len(scopes) != 0 { + t.Errorf("expected no scopes, got %d", len(scopes)) } - if claims == nil { - t.Error("expected claims to be returned, got nil") + }) + t.Run("single scope", func(t *testing.T) { + claims := &JWTClaims{ + Scope: "read", } - if claims.Issuer != "test-issuer" { - t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer) + scopes := claims.GetScopes() + expected := []string{"read"} + + if len(scopes) != 1 { + t.Errorf("expected 1 scope, got %d", len(scopes)) } + if scopes[0] != expected[0] { + t.Errorf("expected scope 'read', got '%s'", scopes[0]) + } + }) - // Test scope parsing + t.Run("multiple scopes", func(t *testing.T) { + claims := &JWTClaims{ + Scope: "read write admin", + } scopes := claims.GetScopes() - expectedScopes := []string{"read", "write", "admin"} + expected := []string{"read", "write", "admin"} + if len(scopes) != 3 { t.Errorf("expected 3 scopes, got %d", len(scopes)) } - for i, expectedScope := range expectedScopes { + + for i, expectedScope := range expected { if i >= len(scopes) || scopes[i] != expectedScope { t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) } } }) - t.Run("audience mismatch", func(t *testing.T) { - // Create a token with wrong audience - wrongAudClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "wrong-audience", - ExpiresAt: time.Now().Add(time.Hour).Unix(), + t.Run("scopes with extra whitespace", func(t *testing.T) { + claims := &JWTClaims{ + Scope: " read write admin ", } + scopes := claims.GetScopes() + expected := []string{"read", "write", "admin"} - jsonBytes, _ := json.Marshal(wrongAudClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - wrongAudToken := "header." + payload + ".signature" - - _, err := validateJWTToken(wrongAudToken, "audience") - if err == nil { - t.Error("expected error for token with wrong audience, got nil") + if len(scopes) != 3 { + t.Errorf("expected 3 scopes, got %d", len(scopes)) } - if !strings.Contains(err.Error(), "audience mismatch") { - t.Errorf("expected audience mismatch error, got %v", err) + for i, expectedScope := range expected { + if i >= len(scopes) || scopes[i] != expectedScope { + t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) + } } }) } @@ -285,85 +320,3 @@ func TestAuthorizationMiddleware(t *testing.T) { } }) } - -func TestJWTClaimsGetScopes(t *testing.T) { - t.Run("single scope", func(t *testing.T) { - claims := &JWTClaims{Scope: "read"} - scopes := claims.GetScopes() - expected := []string{"read"} - - if len(scopes) != 1 { - t.Errorf("expected 1 scope, got %d", len(scopes)) - } - if scopes[0] != expected[0] { - t.Errorf("expected scope 'read', got '%s'", scopes[0]) - } - }) - - t.Run("multiple scopes", func(t *testing.T) { - claims := &JWTClaims{Scope: "read write admin"} - scopes := claims.GetScopes() - expected := []string{"read", "write", "admin"} - - if len(scopes) != 3 { - t.Errorf("expected 3 scopes, got %d", len(scopes)) - } - - for i, expectedScope := range expected { - if i >= len(scopes) || scopes[i] != expectedScope { - t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) - } - } - }) - - t.Run("scopes with extra whitespace", func(t *testing.T) { - claims := &JWTClaims{Scope: " read write admin "} - scopes := claims.GetScopes() - expected := []string{"read", "write", "admin"} - - if len(scopes) != 3 { - t.Errorf("expected 3 scopes, got %d", len(scopes)) - } - - for i, expectedScope := range expected { - if i >= len(scopes) || scopes[i] != expectedScope { - t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) - } - } - }) -} - -func TestJWTClaimsContainsAudience(t *testing.T) { - t.Run("single string audience", func(t *testing.T) { - claims := &JWTClaims{Audience: "test-audience"} - - if !claims.ContainsAudience("test-audience") { - t.Error("expected ContainsAudience to return true for matching audience") - } - - if claims.ContainsAudience("other-audience") { - t.Error("expected ContainsAudience to return false for non-matching audience") - } - }) - - t.Run("array audience", func(t *testing.T) { - claims := &JWTClaims{Audience: []string{"aud1", "aud2", "aud3"}} - - testCases := []struct { - audience string - expected bool - }{ - {"aud1", true}, - {"aud2", true}, - {"aud3", true}, - {"aud4", false}, - {"", false}, - } - - for _, tc := range testCases { - if claims.ContainsAudience(tc.audience) != tc.expected { - t.Errorf("expected ContainsAudience(%s) to return %v", tc.audience, tc.expected) - } - } - }) -}