From 24e8ed345c96ca72b7bc61d35c7c0791f8ca9cfe Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Thu, 17 Jul 2025 12:47:00 +0200 Subject: [PATCH 1/3] fix(auth): delegate JWT parsing to github.com/golang-jwt/jwt Signed-off-by: Marc Nuri --- go.mod | 1 + go.sum | 2 + pkg/http/authorization.go | 98 ++++------ pkg/http/authorization_test.go | 316 ++++++++++++++------------------- 4 files changed, 169 insertions(+), 248 deletions(-) diff --git a/go.mod b/go.mod index e9214d4cd..da5b00679 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/BurntSushi/toml v1.5.0 github.com/coreos/go-oidc/v3 v3.14.1 github.com/fsnotify/fsnotify v1.9.0 + github.com/golang-jwt/jwt/v4 v4.5.2 github.com/mark3labs/mcp-go v0.34.0 github.com/pkg/errors v0.9.1 github.com/spf13/afero v1.14.0 diff --git a/go.sum b/go.sum index 082d047d8..773888ab4 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index c517cb72e..8397d96f7 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -2,14 +2,12 @@ package http import ( "context" - "encoding/base64" - "encoding/json" "fmt" "net/http" "strings" - "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" "k8s.io/klog/v2" "github.com/manusa/kubernetes-mcp-server/pkg/mcp" @@ -55,7 +53,10 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp // Validate the token offline for simple sanity check // Because missing expected audience and expired tokens must be // rejected already. - claims, err := validateJWTToken(token, audience) + claims, err := ParseJWTClaims(token) + if err == nil && claims != nil { + err = claims.Validate(audience) + } if err != nil { klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) @@ -118,80 +119,49 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp } } -type JWTClaims struct { - Issuer string `json:"iss"` - Audience any `json:"aud"` - ExpiresAt int64 `json:"exp"` - Scope string `json:"scope,omitempty"` -} +type JWTClaims jwt.MapClaims func (c *JWTClaims) GetScopes() []string { - if c.Scope == "" { - return nil - } - return strings.Fields(c.Scope) -} - -func (c *JWTClaims) ContainsAudience(audience string) bool { - switch aud := c.Audience.(type) { + scope := jwt.MapClaims(*c)["scope"] + switch scope.(type) { case string: - return aud == audience - case []interface{}: - for _, a := range aud { - if str, ok := a.(string); ok && str == audience { - return true - } - } - case []string: - for _, a := range aud { - if a == audience { - return true - } - } + return strings.Fields(scope.(string)) } - return false + return nil } -// validateJWTToken validates basic JWT claims without signature verification and returns the claims -func validateJWTToken(token, audience string) (*JWTClaims, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("invalid JWT token format") - } - - claims, err := parseJWTClaims(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to parse JWT claims: %v", err) - } - - if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt { - return nil, fmt.Errorf("token expired") - } +func (c *JWTClaims) VerifyAudience(audience string) bool { + return jwt.MapClaims(*c).VerifyAudience(audience, true) +} - if !claims.ContainsAudience(audience) { - return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience) - } +func (c *JWTClaims) VerifyExpiresAt(expriesAt int64) bool { + return jwt.MapClaims(*c).VerifyExpiresAt(expriesAt, true) +} - return claims, nil +func (c *JWTClaims) VerifyIssuer(issuer string) bool { + return jwt.MapClaims(*c).VerifyIssuer(issuer, true) } -func parseJWTClaims(payload string) (*JWTClaims, error) { - // Add padding if needed - if len(payload)%4 != 0 { - payload += strings.Repeat("=", 4-len(payload)%4) - } +func (c *JWTClaims) Valid() error { + return jwt.MapClaims(*c).Valid() +} - decoded, err := base64.URLEncoding.DecodeString(payload) - if err != nil { - return nil, fmt.Errorf("failed to decode JWT payload: %v", err) +// Validate Checks if the JWT claims are valid and if the audience matches the expected one. +func (c *JWTClaims) Validate(audience string) error { + if err := c.Valid(); err != nil { + return err } - - var claims JWTClaims - if err := json.Unmarshal(decoded, &claims); err != nil { - return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err) + if !c.VerifyAudience(audience) { + return fmt.Errorf("token audience mismatch: %v", jwt.MapClaims(*c)["aud"]) } + return nil +} - return &claims, nil +func ParseJWTClaims(token string) (*JWTClaims, error) { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + mapClaims := &JWTClaims{} + _, _, err := parser.ParseUnverified(token, mapClaims) + return mapClaims, err } func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error { diff --git a/pkg/http/authorization_test.go b/pkg/http/authorization_test.go index 0ffe816e0..7a93eb8c6 100644 --- a/pkg/http/authorization_test.go +++ b/pkg/http/authorization_test.go @@ -1,187 +1,217 @@ package http import ( - "encoding/base64" - "encoding/json" "net/http" "net/http/httptest" "strings" "testing" - "time" ) -func TestParseJWTClaims(t *testing.T) { - t.Run("valid JWT payload", func(t *testing.T) { - // Sample payload from a valid JWT - payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9" +const ( + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA // notsecret + tokenBasicNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA" // notsecret + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA // notsecret + tokenBasicExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA" // notsecret + // https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ // notsecret + tokenMultipleAudienceNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ" // notsecret +) - claims, err := parseJWTClaims(payload) +func TestParseJWTClaimsPayloadValid(t *testing.T) { + basicClaims, err := ParseJWTClaims(tokenBasicNotExpired) + t.Run("Is parseable", func(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - - if claims == nil { + if basicClaims == nil { t.Fatal("expected claims, got nil") } - - if claims.Issuer != "https://kubernetes.default.svc.cluster.local" { - t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer) + }) + t.Run("Parses issuer", func(t *testing.T) { + if !basicClaims.VerifyIssuer("https://kubernetes.default.svc.cluster.local") { + t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", (*basicClaims)["iss"]) } - + }) + t.Run("Parses audience", func(t *testing.T) { expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"} for _, expected := range expectedAudiences { - if !claims.ContainsAudience(expected) { + if !basicClaims.VerifyAudience(expected) { t.Errorf("expected audience to contain %s", expected) } } - - if claims.ExpiresAt != 1751963948 { - t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt) + }) + t.Run("Parses expiration", func(t *testing.T) { + if basicClaims.VerifyExpiresAt(253402297199) { + t.Errorf("expected expiration 1751963948, got %d", (*basicClaims)["exp"]) } }) - - t.Run("payload needs padding", func(t *testing.T) { - // Create a payload that needs padding - testClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "test-audience", - ExpiresAt: time.Now().Add(time.Hour).Unix(), + t.Run("Parses scope", func(t *testing.T) { + scopeClaims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if scopeClaims == nil { + t.Fatal("expected claims, got nil") } - jsonBytes, _ := json.Marshal(testClaims) - // Create a payload without proper padding - encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=") + scopes := scopeClaims.GetScopes() - claims, err := parseJWTClaims(encodedWithoutPadding) + expectedScopes := []string{"read", "write"} + if len(scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(scopes)) + } + for i, expectedScope := range expectedScopes { + if scopes[i] != expectedScope { + t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) + } + } + }) + t.Run("Parses expired token", func(t *testing.T) { + expiredClaims, err := ParseJWTClaims(tokenBasicExpired) if err != nil { t.Fatalf("expected no error, got %v", err) } - if claims.Issuer != "test-issuer" { - t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer) + if expiredClaims.VerifyExpiresAt(1) { + t.Errorf("expected expiration 1751963948, got %d", (*basicClaims)["exp"]) } }) +} - t.Run("invalid base64 payload", func(t *testing.T) { - invalidPayload := "invalid-base64!!!" +func TestParseJWTClaimsPayloadInvalid(t *testing.T) { + t.Run("invalid token segments", func(t *testing.T) { + invalidToken := "header.payload.signature.extra" - _, err := parseJWTClaims(invalidPayload) + _, err := ParseJWTClaims(invalidToken) if err == nil { - t.Error("expected error for invalid base64, got nil") + t.Fatal("expected error for invalid token segments, got nil") } - if !strings.Contains(err.Error(), "failed to decode JWT payload") { - t.Errorf("expected decode error message, got %v", err) + if !strings.Contains(err.Error(), "token contains an invalid number of segments") { + t.Errorf("expected invalid token segments error message, got %v", err) } }) + t.Run("invalid base64 payload", func(t *testing.T) { + invalidPayload := "invalid_base64" + tokenBasicNotExpired - t.Run("invalid JSON payload", func(t *testing.T) { - // Valid base64 but invalid JSON - invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json")) - - _, err := parseJWTClaims(invalidJSON) + _, err := ParseJWTClaims(invalidPayload) if err == nil { - t.Error("expected error for invalid JSON, got nil") + t.Fatal("expected error for invalid base64, got nil") } - if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") { - t.Errorf("expected unmarshal error message, got %v", err) + if !strings.Contains(err.Error(), "illegal base64 data") { + t.Errorf("expected decode error message, got %v", err) } }) } -func TestValidateJWTToken(t *testing.T) { - t.Run("invalid token format - not enough parts", func(t *testing.T) { - invalidToken := "header.payload" +func TestJWTTokenValidate(t *testing.T) { + t.Run("expired token returns error", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenBasicExpired) + if err != nil { + t.Fatalf("expected no error for expired token parsing, got %v", err) + } - _, err := validateJWTToken(invalidToken, "test") + err = claims.Validate("kubernetes-mcp-server") if err == nil { - t.Error("expected error for invalid token format, got nil") + t.Fatalf("expected error for expired token, got nil") } - if !strings.Contains(err.Error(), "invalid JWT token format") { - t.Errorf("expected format error message, got %v", err) + if !strings.Contains(err.Error(), "Token is expired") { + t.Errorf("expected expiration error message, got %v", err) } }) - t.Run("expired token", func(t *testing.T) { - // Create an expired token - expiredClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "kubernetes-mcp-server", - ExpiresAt: time.Now().Add(-time.Hour).Unix(), + t.Run("multiple audiences with correct one", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error for multiple audience token parsing, got %v", err) + } + if claims == nil { + t.Fatalf("expected claims to be returned, got nil") } - jsonBytes, _ := json.Marshal(expiredClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - expiredToken := "header." + payload + ".signature" + err = claims.Validate("kubernetes-mcp-server") + if err != nil { + t.Fatalf("expected no error for valid audience, got %v", err) + } + }) - _, err := validateJWTToken(expiredToken, "kubernetes-mcp-server") + t.Run("multiple audiences with mismatch returns error", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired) + if err != nil { + t.Fatalf("expected no error for multiple audience token parsing, got %v", err) + } + if claims == nil { + t.Fatalf("expected claims to be returned, got nil") + } + + err = claims.Validate("missing-audience") if err == nil { - t.Error("expected error for expired token, got nil") + t.Fatalf("expected error for token with wrong audience, got nil") } - if !strings.Contains(err.Error(), "token expired") { - t.Errorf("expected expiration error message, got %v", err) + if !strings.Contains(err.Error(), "token audience mismatch") { + t.Errorf("expected audience mismatch error, got %v", err) } }) +} - t.Run("multiple audiences with correct one", func(t *testing.T) { - // Create a token with multiple audiences including the correct one - multiAudClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"}, - ExpiresAt: time.Now().Add(time.Hour).Unix(), - Scope: "read write admin", +func TestJWTClaimsGetScopes(t *testing.T) { + t.Run("no scopes", func(t *testing.T) { + claims, err := ParseJWTClaims(tokenBasicExpired) + if err != nil { + t.Fatalf("expected no error for parsing token, got %v", err) } - jsonBytes, _ := json.Marshal(multiAudClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - multiAudToken := "header." + payload + ".signature" - - claims, err := validateJWTToken(multiAudToken, "kubernetes-mcp-server") - if err != nil { - t.Errorf("expected no error for token with multiple audiences, got %v", err) + if scopes := claims.GetScopes(); len(scopes) != 0 { + t.Errorf("expected no scopes, got %d", len(scopes)) } - if claims == nil { - t.Error("expected claims to be returned, got nil") + }) + t.Run("single scope", func(t *testing.T) { + claims := &JWTClaims{} + (*claims)["scope"] = "read" + scopes := claims.GetScopes() + expected := []string{"read"} + + if len(scopes) != 1 { + t.Errorf("expected 1 scope, got %d", len(scopes)) } - if claims.Issuer != "test-issuer" { - t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer) + if scopes[0] != expected[0] { + t.Errorf("expected scope 'read', got '%s'", scopes[0]) } + }) - // Test scope parsing + t.Run("multiple scopes", func(t *testing.T) { + claims := &JWTClaims{} + (*claims)["scope"] = "read write admin" scopes := claims.GetScopes() - expectedScopes := []string{"read", "write", "admin"} + expected := []string{"read", "write", "admin"} + if len(scopes) != 3 { t.Errorf("expected 3 scopes, got %d", len(scopes)) } - for i, expectedScope := range expectedScopes { + + for i, expectedScope := range expected { if i >= len(scopes) || scopes[i] != expectedScope { t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) } } }) - t.Run("audience mismatch", func(t *testing.T) { - // Create a token with wrong audience - wrongAudClaims := JWTClaims{ - Issuer: "test-issuer", - Audience: "wrong-audience", - ExpiresAt: time.Now().Add(time.Hour).Unix(), - } - - jsonBytes, _ := json.Marshal(wrongAudClaims) - payload := base64.URLEncoding.EncodeToString(jsonBytes) - wrongAudToken := "header." + payload + ".signature" + t.Run("scopes with extra whitespace", func(t *testing.T) { + claims := &JWTClaims{} + (*claims)["scope"] = " read write admin " + scopes := claims.GetScopes() + expected := []string{"read", "write", "admin"} - _, err := validateJWTToken(wrongAudToken, "audience") - if err == nil { - t.Error("expected error for token with wrong audience, got nil") + if len(scopes) != 3 { + t.Errorf("expected 3 scopes, got %d", len(scopes)) } - if !strings.Contains(err.Error(), "audience mismatch") { - t.Errorf("expected audience mismatch error, got %v", err) + for i, expectedScope := range expected { + if i >= len(scopes) || scopes[i] != expectedScope { + t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) + } } }) } @@ -285,85 +315,3 @@ func TestAuthorizationMiddleware(t *testing.T) { } }) } - -func TestJWTClaimsGetScopes(t *testing.T) { - t.Run("single scope", func(t *testing.T) { - claims := &JWTClaims{Scope: "read"} - scopes := claims.GetScopes() - expected := []string{"read"} - - if len(scopes) != 1 { - t.Errorf("expected 1 scope, got %d", len(scopes)) - } - if scopes[0] != expected[0] { - t.Errorf("expected scope 'read', got '%s'", scopes[0]) - } - }) - - t.Run("multiple scopes", func(t *testing.T) { - claims := &JWTClaims{Scope: "read write admin"} - scopes := claims.GetScopes() - expected := []string{"read", "write", "admin"} - - if len(scopes) != 3 { - t.Errorf("expected 3 scopes, got %d", len(scopes)) - } - - for i, expectedScope := range expected { - if i >= len(scopes) || scopes[i] != expectedScope { - t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) - } - } - }) - - t.Run("scopes with extra whitespace", func(t *testing.T) { - claims := &JWTClaims{Scope: " read write admin "} - scopes := claims.GetScopes() - expected := []string{"read", "write", "admin"} - - if len(scopes) != 3 { - t.Errorf("expected 3 scopes, got %d", len(scopes)) - } - - for i, expectedScope := range expected { - if i >= len(scopes) || scopes[i] != expectedScope { - t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i]) - } - } - }) -} - -func TestJWTClaimsContainsAudience(t *testing.T) { - t.Run("single string audience", func(t *testing.T) { - claims := &JWTClaims{Audience: "test-audience"} - - if !claims.ContainsAudience("test-audience") { - t.Error("expected ContainsAudience to return true for matching audience") - } - - if claims.ContainsAudience("other-audience") { - t.Error("expected ContainsAudience to return false for non-matching audience") - } - }) - - t.Run("array audience", func(t *testing.T) { - claims := &JWTClaims{Audience: []string{"aud1", "aud2", "aud3"}} - - testCases := []struct { - audience string - expected bool - }{ - {"aud1", true}, - {"aud2", true}, - {"aud3", true}, - {"aud4", false}, - {"", false}, - } - - for _, tc := range testCases { - if claims.ContainsAudience(tc.audience) != tc.expected { - t.Errorf("expected ContainsAudience(%s) to return %v", tc.audience, tc.expected) - } - } - }) -} From f28131345c43acd6f7b27f8be95f71b559a29943 Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Thu, 17 Jul 2025 15:14:11 +0200 Subject: [PATCH 2/3] fix(auth): delegate JWT parsing to go-jose Signed-off-by: Marc Nuri --- go.mod | 3 +- go.sum | 2 -- pkg/http/authorization.go | 66 +++++++++++++++++----------------- pkg/http/authorization_test.go | 37 ++++++++++--------- 4 files changed, 55 insertions(+), 53 deletions(-) diff --git a/go.mod b/go.mod index da5b00679..13e9f8d67 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/BurntSushi/toml v1.5.0 github.com/coreos/go-oidc/v3 v3.14.1 github.com/fsnotify/fsnotify v1.9.0 - github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/go-jose/go-jose/v4 v4.0.5 github.com/mark3labs/mcp-go v0.34.0 github.com/pkg/errors v0.9.1 github.com/spf13/afero v1.14.0 @@ -53,7 +53,6 @@ require ( github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-errors/errors v1.4.2 // indirect github.com/go-gorp/gorp/v3 v3.1.0 // indirect - github.com/go-jose/go-jose/v4 v4.0.5 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect diff --git a/go.sum b/go.sum index 773888ab4..082d047d8 100644 --- a/go.sum +++ b/go.sum @@ -118,8 +118,6 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= -github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 8397d96f7..6cd263bf3 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -7,7 +7,8 @@ import ( "strings" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt/v4" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "k8s.io/klog/v2" "github.com/manusa/kubernetes-mcp-server/pkg/mcp" @@ -119,49 +120,48 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp } } -type JWTClaims jwt.MapClaims - -func (c *JWTClaims) GetScopes() []string { - scope := jwt.MapClaims(*c)["scope"] - switch scope.(type) { - case string: - return strings.Fields(scope.(string)) - } - return nil -} - -func (c *JWTClaims) VerifyAudience(audience string) bool { - return jwt.MapClaims(*c).VerifyAudience(audience, true) -} - -func (c *JWTClaims) VerifyExpiresAt(expriesAt int64) bool { - return jwt.MapClaims(*c).VerifyExpiresAt(expriesAt, true) +var allSignatureAlgorithms = []jose.SignatureAlgorithm{ + jose.EdDSA, + jose.HS256, + jose.HS384, + jose.HS512, + jose.RS256, + jose.RS384, + jose.RS512, + jose.ES256, + jose.ES384, + jose.ES512, + jose.PS256, + jose.PS384, + jose.PS512, } -func (c *JWTClaims) VerifyIssuer(issuer string) bool { - return jwt.MapClaims(*c).VerifyIssuer(issuer, true) +type JWTClaims struct { + jwt.Claims + Scope string `json:"scope,omitempty"` } -func (c *JWTClaims) Valid() error { - return jwt.MapClaims(*c).Valid() +func (c *JWTClaims) GetScopes() []string { + if c.Scope == "" { + return nil + } + return strings.Fields(c.Scope) } // Validate Checks if the JWT claims are valid and if the audience matches the expected one. func (c *JWTClaims) Validate(audience string) error { - if err := c.Valid(); err != nil { - return err - } - if !c.VerifyAudience(audience) { - return fmt.Errorf("token audience mismatch: %v", jwt.MapClaims(*c)["aud"]) - } - return nil + return c.Claims.Validate(jwt.Expected{ + AnyAudience: jwt.Audience{audience}, + }) } func ParseJWTClaims(token string) (*JWTClaims, error) { - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) - mapClaims := &JWTClaims{} - _, _, err := parser.ParseUnverified(token, mapClaims) - return mapClaims, err + tkn, err := jwt.ParseSigned(token, allSignatureAlgorithms) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) + } + claims := &JWTClaims{} + return claims, tkn.UnsafeClaimsWithoutVerification(claims) } func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error { diff --git a/pkg/http/authorization_test.go b/pkg/http/authorization_test.go index 7a93eb8c6..9dd45111c 100644 --- a/pkg/http/authorization_test.go +++ b/pkg/http/authorization_test.go @@ -5,6 +5,8 @@ import ( "net/http/httptest" "strings" "testing" + + "github.com/go-jose/go-jose/v4/jwt" ) const ( @@ -27,21 +29,21 @@ func TestParseJWTClaimsPayloadValid(t *testing.T) { } }) t.Run("Parses issuer", func(t *testing.T) { - if !basicClaims.VerifyIssuer("https://kubernetes.default.svc.cluster.local") { - t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", (*basicClaims)["iss"]) + if basicClaims.Issuer != "https://kubernetes.default.svc.cluster.local" { + t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", basicClaims.Issuer) } }) t.Run("Parses audience", func(t *testing.T) { expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"} for _, expected := range expectedAudiences { - if !basicClaims.VerifyAudience(expected) { + if !basicClaims.Audience.Contains(expected) { t.Errorf("expected audience to contain %s", expected) } } }) t.Run("Parses expiration", func(t *testing.T) { - if basicClaims.VerifyExpiresAt(253402297199) { - t.Errorf("expected expiration 1751963948, got %d", (*basicClaims)["exp"]) + if *basicClaims.Expiry != jwt.NumericDate(253402297199) { + t.Errorf("expected expiration 253402297199, got %d", basicClaims.Expiry) } }) t.Run("Parses scope", func(t *testing.T) { @@ -71,8 +73,8 @@ func TestParseJWTClaimsPayloadValid(t *testing.T) { t.Fatalf("expected no error, got %v", err) } - if expiredClaims.VerifyExpiresAt(1) { - t.Errorf("expected expiration 1751963948, got %d", (*basicClaims)["exp"]) + if *expiredClaims.Expiry != jwt.NumericDate(1) { + t.Errorf("expected expiration 1, got %d", basicClaims.Expiry) } }) } @@ -86,7 +88,7 @@ func TestParseJWTClaimsPayloadInvalid(t *testing.T) { t.Fatal("expected error for invalid token segments, got nil") } - if !strings.Contains(err.Error(), "token contains an invalid number of segments") { + if !strings.Contains(err.Error(), "compact JWS format must have three parts") { t.Errorf("expected invalid token segments error message, got %v", err) } }) @@ -116,7 +118,7 @@ func TestJWTTokenValidate(t *testing.T) { t.Fatalf("expected error for expired token, got nil") } - if !strings.Contains(err.Error(), "Token is expired") { + if !strings.Contains(err.Error(), "token is expired (exp)") { t.Errorf("expected expiration error message, got %v", err) } }) @@ -150,7 +152,7 @@ func TestJWTTokenValidate(t *testing.T) { t.Fatalf("expected error for token with wrong audience, got nil") } - if !strings.Contains(err.Error(), "token audience mismatch") { + if !strings.Contains(err.Error(), "invalid audience claim (aud)") { t.Errorf("expected audience mismatch error, got %v", err) } }) @@ -168,8 +170,9 @@ func TestJWTClaimsGetScopes(t *testing.T) { } }) t.Run("single scope", func(t *testing.T) { - claims := &JWTClaims{} - (*claims)["scope"] = "read" + claims := &JWTClaims{ + Scope: "read", + } scopes := claims.GetScopes() expected := []string{"read"} @@ -182,8 +185,9 @@ func TestJWTClaimsGetScopes(t *testing.T) { }) t.Run("multiple scopes", func(t *testing.T) { - claims := &JWTClaims{} - (*claims)["scope"] = "read write admin" + claims := &JWTClaims{ + Scope: "read write admin", + } scopes := claims.GetScopes() expected := []string{"read", "write", "admin"} @@ -199,8 +203,9 @@ func TestJWTClaimsGetScopes(t *testing.T) { }) t.Run("scopes with extra whitespace", func(t *testing.T) { - claims := &JWTClaims{} - (*claims)["scope"] = " read write admin " + claims := &JWTClaims{ + Scope: " read write admin ", + } scopes := claims.GetScopes() expected := []string{"read", "write", "admin"} From 5cc6098bc9b9e79f9f5e095ab88c34fb5b346ccc Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Fri, 18 Jul 2025 11:40:25 +0200 Subject: [PATCH 3/3] fix(auth): delegate JWT parsing to go-jose - review comment Signed-off-by: Marc Nuri --- pkg/http/authorization.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 6cd263bf3..d46e25f64 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -161,7 +161,8 @@ func ParseJWTClaims(token string) (*JWTClaims, error) { return nil, fmt.Errorf("failed to parse JWT token: %w", err) } claims := &JWTClaims{} - return claims, tkn.UnsafeClaimsWithoutVerification(claims) + err = tkn.UnsafeClaimsWithoutVerification(claims) + return claims, err } func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {