diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index 1e855fec0d6e2..de8d27a291ca2 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -22,13 +22,16 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "errors" "fmt" "math/big" "net" "os" + "regexp" "slices" + "strings" "testing" "time" @@ -750,6 +753,33 @@ func TestIssueWorkloadIdentity(t *testing.T) { }, requireErr: require.NoError, assert: func(t *testing.T, res *workloadidentityv1pb.IssueWorkloadIdentityResponse) { + // Checks for a bug where unix epoch timestamps (e.g. the `exp` + // and `iat` claims) were represented in scientific notation + // rather than as plain integers due to a conversion bug. + payloadSection := strings.Split( + res.GetCredential().GetJwtSvid().GetJwt(), + ".", + )[1] + payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(payloadSection, "=")) + require.NoError(t, err) + + var numericClaims struct { + Exp json.Number `json:"exp"` + Iat json.Number `json:"iat"` + } + require.NoError(t, json.Unmarshal(payload, &numericClaims)) + + integerExpr, err := regexp.Compile(`^\d+$`) + require.NoError(t, err) + require.Truef(t, + integerExpr.MatchString(numericClaims.Exp.String()), + "unexpected number format: %s", numericClaims.Exp.String(), + ) + require.Truef(t, + integerExpr.MatchString(numericClaims.Iat.String()), + "unexpected number format: %s", numericClaims.Iat.String(), + ) + parsed, err := jwt.ParseSigned(res.GetCredential().GetJwtSvid().GetJwt()) require.NoError(t, err) diff --git a/lib/jwt/jwt.go b/lib/jwt/jwt.go index 01ff10ba4f37c..d8ca5d61f5ef0 100644 --- a/lib/jwt/jwt.go +++ b/lib/jwt/jwt.go @@ -306,35 +306,46 @@ type SignParamsJWTSVID struct { func (k *Key) SignJWTSVID(p SignParamsJWTSVID) (string, error) { // Record time here for consistency between exp and iat. now := k.config.Clock.Now() - claims := jwt.Claims{ + + // We use map[string]any instead of jwt.Claims to avoid a json.Marshal/Unmarshal + // round-trip that would convert jwt.NumericDate (int64) to float64, causing + // timestamp claims to be serialized in scientific notation (e.g., "exp": 1.7e9). + // Using map[string]any preserves the jwt.NumericDate type until final marshaling. + claims := map[string]any{ // > 3.1. Subject: // > The sub claim MUST be set to the SPIFFE ID of the workload to which it is issued. - Subject: p.SPIFFEID.String(), + "sub": p.SPIFFEID.String(), + // > 3.2. Audience: // > The aud claim MUST be present, containing one or more values. - Audience: p.Audiences, + "aud": jwt.Audience(p.Audiences), + // > 3.3. Expiration Time: // > The exp claim MUST be set - Expiry: jwt.NewNumericDate(now.Add(p.TTL)), + "exp": jwt.NewNumericDate(now.Add(p.TTL)), + // The spec makes no comment on inclusion of `iat`, but the SPIRE // implementation does set this value and it feels like a good idea. - IssuedAt: jwt.NewNumericDate(now), + "iat": jwt.NewNumericDate(now), + // > 7.1. Replay Protection // > the jti claim is permitted by this specification, it should be // > noted that JWT-SVID validators are not required to track jti // > uniqueness. - ID: p.JTI, + "jti": p.JTI, + // The SPIFFE specification makes no comment on the inclusion of `iss`, // however, we provide this value so that the issued token can be a // valid OIDC ID token and used with non-SPIFFE aware systems that do // understand OIDC. - Issuer: p.Issuer, + "iss": p.Issuer, } + if !p.SetIssuedAt.IsZero() { - claims.IssuedAt = jwt.NewNumericDate(p.SetIssuedAt) + claims["iat"] = jwt.NewNumericDate(p.SetIssuedAt) } if !p.SetExpiry.IsZero() { - claims.Expiry = jwt.NewNumericDate(p.SetExpiry) + claims["exp"] = jwt.NewNumericDate(p.SetExpiry) } // > 2.2. Key ID: @@ -361,30 +372,17 @@ func (k *Key) SignJWTSVID(p SignParamsJWTSVID) (string, error) { // // > Registered claims not described in this document, in addition to // > private claims, MAY be used as implementers see fit. - var rawClaims any = claims if len(p.PrivateClaims) != 0 { - // This is slightly awkward. We take a round-trip through json.Marshal - // and json.Unmarshal to get a version of the claims we can add to. - marshaled, err := json.Marshal(rawClaims) - if err != nil { - return "", trace.Wrap(err, "marshaling claims") - } - var unmarshaled map[string]any - if err := json.Unmarshal(marshaled, &unmarshaled); err != nil { - return "", trace.Wrap(err, "unmarshaling claims") - } - // Only inject claims that don't conflict with an existing primary claim // such as sub or aud. for k, v := range p.PrivateClaims { - if _, ok := unmarshaled[k]; !ok { - unmarshaled[k] = v + if _, ok := claims[k]; !ok { + claims[k] = v } } - rawClaims = unmarshaled } - return k.sign(rawClaims, opts) + return k.sign(claims, opts) } // SignEntraOIDC signs a JWT for the Entra ID Integration.