Skip to content

Commit 169796b

Browse files
hcsa73Henrique Santos
andauthored
Remove token validation using JWKS in KeyFlow (#328)
* Remove token validation with JWKS * Fix comment * Add changelog * Fix changelog --------- Co-authored-by: Henrique Santos <[email protected]>
1 parent 1a60853 commit 169796b

File tree

8 files changed

+55
-187
lines changed

8 files changed

+55
-187
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- **Feature:** Add package `runtime`, which implements methods to be used when performing API requests
55
- **Feature:** Add method `WithCaptureHTTPResponse` to package `runtime`, which does the same as `config.WithCaptureHTTPResponse`. Method was moved to avoid confusion due to it not being a configuration option, and will be removed in a later release.
66
- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
7+
- **Deprecation:** Marked method `config.WithJWKSEndpoint` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect
78

89
## Release (2024-02-07)
910

core/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## v0.9.0 (YYYY-MM-DD)
22

3-
- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
3+
- **Deprecation:** Marked method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due
4+
- **Deprecation:** Marked method `config.WithJWKSEndpoint` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect
45

56
## v0.8.0 (2024-02-16)
67

core/auth/auth.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,19 +178,12 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
178178
cfg.TokenCustomUrl = tokenCustomUrl
179179
}
180180
}
181-
if cfg.JWKSCustomUrl == "" {
182-
jwksCustomUrl, jwksUrlSet := os.LookupEnv("STACKIT_JWKS_BASEURL")
183-
if jwksUrlSet {
184-
cfg.JWKSCustomUrl = jwksCustomUrl
185-
}
186-
}
187181

188182
keyCfg := clients.KeyFlowConfig{
189183
ServiceAccountKey: serviceAccountKey,
190184
PrivateKey: cfg.PrivateKey,
191185
ClientRetry: cfg.RetryOptions,
192186
TokenUrl: cfg.TokenCustomUrl,
193-
JWKSUrl: cfg.JWKSCustomUrl,
194187
}
195188

196189
client := &clients.KeyFlow{}

core/clients/key_flow.go

Lines changed: 14 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@ import (
55
"crypto/x509"
66
"encoding/json"
77
"encoding/pem"
8-
"errors"
98
"fmt"
109
"io"
1110
"net/http"
1211
"net/url"
1312
"strings"
1413
"time"
1514

16-
"github.com/MicahParks/keyfunc/v2"
1715
"github.com/golang-jwt/jwt/v5"
1816
"github.com/google/uuid"
1917
)
@@ -26,7 +24,6 @@ const (
2624
ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH"
2725
PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH"
2826
tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive
29-
jwksAPI = "https://service-account.api.stackit.cloud/.well-known/jwks.json"
3027
defaultTokenType = "Bearer"
3128
defaultScope = ""
3229
)
@@ -48,7 +45,6 @@ type KeyFlowConfig struct {
4845
PrivateKey string
4946
ClientRetry *RetryConfig
5047
TokenUrl string
51-
JWKSUrl string
5248
}
5349

5450
// TokenResponseBody is the API response
@@ -112,13 +108,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
112108
c.config = cfg
113109
c.doer = Do
114110

115-
// set defaults if no custom token and jwks url are provided
116111
if c.config.TokenUrl == "" {
117112
c.config.TokenUrl = tokenAPI
118113
}
119-
if c.config.JWKSUrl == "" {
120-
c.config.JWKSUrl = jwksAPI
121-
}
122114
c.configureHTTPClient()
123115
if c.config.ClientRetry == nil {
124116
c.config.ClientRetry = NewRetryConfig()
@@ -181,11 +173,11 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
181173

182174
// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
183175
func (c *KeyFlow) GetAccessToken() (string, error) {
184-
accessTokenIsValid, err := c.validateToken(c.token.AccessToken)
176+
accessTokenExpired, err := tokenExpired(c.token.AccessToken)
185177
if err != nil {
186178
return "", fmt.Errorf("failed initial validation: %w", err)
187179
}
188-
if accessTokenIsValid {
180+
if !accessTokenExpired {
189181
return c.token.AccessToken, nil
190182
}
191183
if err := c.recreateAccessToken(); err != nil {
@@ -232,11 +224,11 @@ func (c *KeyFlow) validate() error {
232224
// recreateAccessToken is used to create a new access token
233225
// when the existing one isn't valid anymore
234226
func (c *KeyFlow) recreateAccessToken() error {
235-
refreshTokenIsValid, err := c.validateToken(c.token.RefreshToken)
227+
refreshTokenExpired, err := tokenExpired(c.token.RefreshToken)
236228
if err != nil {
237229
return err
238230
}
239-
if refreshTokenIsValid {
231+
if !refreshTokenExpired {
240232
return c.createAccessTokenWithRefreshToken()
241233
}
242234
return c.createAccessToken()
@@ -327,54 +319,18 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error {
327319
return json.Unmarshal(body, c.token)
328320
}
329321

330-
// validateToken returns true if token is valid
331-
func (c *KeyFlow) validateToken(token string) (bool, error) {
332-
if token == "" {
333-
return false, nil
334-
}
335-
if _, err := c.parseToken(token); err != nil {
336-
if errors.Is(err, jwt.ErrTokenExpired) {
337-
c.token = &TokenResponseBody{}
338-
return false, nil
339-
}
340-
return false, fmt.Errorf("parse token: %w", err)
341-
}
342-
return true, nil
343-
}
344-
345-
// parseToken parses and validates a JWT token
346-
func (c *KeyFlow) parseToken(token string) (*jwt.Token, error) {
347-
b, err := c.getJwksJSON(token)
348-
if err != nil {
349-
return nil, fmt.Errorf("get JWKS Json: %w", err)
350-
}
351-
var jwksBytes = json.RawMessage(b)
352-
jwks, err := keyfunc.NewJSON(jwksBytes)
353-
if err != nil {
354-
return nil, fmt.Errorf("get JWKS function from JSON: %w", err)
355-
}
356-
return jwt.Parse(token, jwks.Keyfunc)
357-
}
358-
359-
func (c *KeyFlow) getJwksJSON(token string) (jwks []byte, err error) {
360-
req, err := http.NewRequest("GET", c.config.JWKSUrl, http.NoBody)
322+
func tokenExpired(token string) (bool, error) {
323+
// We can safely use ParseUnverified because we are not authenticating the user at this point.
324+
// We're just checking the expiration time
325+
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
361326
if err != nil {
362-
return nil, err
327+
return false, fmt.Errorf("parse access token: %w", err)
363328
}
364-
req.Header.Add("Authorization", "Bearer "+token)
365-
res, err := c.doer(&http.Client{}, req, c.config.ClientRetry)
329+
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
366330
if err != nil {
367-
return nil, err
368-
}
369-
defer func() {
370-
tempErr := res.Body.Close()
371-
if tempErr != nil {
372-
jwks = nil
373-
err = fmt.Errorf("closing get jwks response: %w", tempErr)
374-
}
375-
}()
376-
if res.StatusCode != 200 {
377-
return nil, fmt.Errorf("getting jwks return error status %s", res.Status)
331+
return false, fmt.Errorf("get expiration timestamp from access token: %w", err)
378332
}
379-
return io.ReadAll(res.Body)
333+
expirationTimestamp := expirationTimestampNumeric.Time
334+
now := time.Now()
335+
return now.After(expirationTimestamp), nil
380336
}

core/clients/key_flow_test.go

Lines changed: 35 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package clients
22

33
import (
4-
"bytes"
54
"crypto/rand"
65
"crypto/rsa"
76
"crypto/x509"
@@ -21,7 +20,6 @@ import (
2120

2221
var (
2322
testSigningKey = []byte(`Test`)
24-
testJwks = []byte(`{ "keys": [ { "kty":"oct", "kid":"test", "alg":"HS256" } ] }`)
2523
)
2624

2725
func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse {
@@ -126,10 +124,6 @@ func TestKeyFlowInit(t *testing.T) {
126124
}
127125
}
128126

129-
type MyCustomClaims struct {
130-
Foo string `json:"foo"`
131-
}
132-
133127
func TestSetToken(t *testing.T) {
134128
tests := []struct {
135129
name string
@@ -207,127 +201,55 @@ func TestKeyClone(t *testing.T) {
207201
}
208202
}
209203

210-
func TestKeyFlowValidateToken(t *testing.T) {
211-
// Generate a random private key
212-
privateKey := make([]byte, 32)
213-
if _, err := rand.Read(privateKey); err != nil {
214-
t.Fatal(err)
215-
}
204+
func TestTokenExpired(t *testing.T) {
216205
tests := []struct {
217-
name string
218-
token string
219-
jwksMockResponse *http.Response
220-
jwksMockError error
221-
want bool
222-
wantErr bool
206+
desc string
207+
tokenInvalid bool
208+
tokenExpiresAt time.Time
209+
expectedErr bool
210+
expectedIsExpired bool
223211
}{
224212
{
225-
name: "no token",
226-
token: "",
227-
want: false,
228-
wantErr: false,
213+
desc: "token valid",
214+
tokenExpiresAt: time.Now().Add(time.Hour),
215+
expectedErr: false,
216+
expectedIsExpired: false,
229217
},
230218
{
231-
name: "invalid token",
232-
token: "bad token",
233-
jwksMockResponse: &http.Response{
234-
StatusCode: 200,
235-
Body: io.NopCloser(bytes.NewReader(testJwks)),
236-
},
237-
jwksMockError: nil,
238-
want: false,
239-
wantErr: true,
219+
desc: "token expired",
220+
tokenExpiresAt: time.Now().Add(-time.Hour),
221+
expectedErr: false,
222+
expectedIsExpired: true,
240223
},
241224
{
242-
name: "get_jwks_fail",
243-
token: "bad token",
244-
jwksMockResponse: nil,
245-
jwksMockError: fmt.Errorf("error"),
246-
want: false,
247-
wantErr: true,
225+
desc: "token invalid",
226+
tokenInvalid: true,
227+
expectedErr: true,
248228
},
249229
}
250-
for _, tt := range tests {
251-
t.Run(tt.name, func(t *testing.T) {
252-
mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) {
253-
return tt.jwksMockResponse, tt.jwksMockError
254-
}
255-
c := &KeyFlow{
256-
config: &KeyFlowConfig{
257-
PrivateKey: string(privateKey),
258-
JWKSUrl: jwksAPI,
259-
},
260-
doer: mockDo,
261-
}
262230

263-
got, err := c.validateToken(tt.token)
264-
if (err != nil) != tt.wantErr {
265-
t.Errorf("KeyFlow.validateToken() error = %v, wantErr %v", err, tt.wantErr)
266-
return
267-
}
268-
if got != tt.want {
269-
t.Errorf("KeyFlow.validateToken() = %v, want %v", got, tt.want)
231+
for _, tt := range tests {
232+
t.Run(tt.desc, func(t *testing.T) {
233+
var err error
234+
token := "foo"
235+
if !tt.tokenInvalid {
236+
token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
237+
ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt),
238+
}).SignedString([]byte("test"))
239+
if err != nil {
240+
t.Fatalf("failed to create token: %v", err)
241+
}
270242
}
271-
})
272-
}
273-
}
274-
275-
func TestGetJwksJSON(t *testing.T) {
276-
testCases := []struct {
277-
name string
278-
token string
279-
mockResponse *http.Response
280-
mockError error
281-
expectedResult []byte
282-
expectedError error
283-
}{
284-
{
285-
name: "Success",
286-
token: "test_token",
287-
mockResponse: &http.Response{
288-
StatusCode: 200,
289-
Body: io.NopCloser(bytes.NewReader([]byte(`{"key": "value"}`))),
290-
},
291-
mockError: nil,
292-
expectedResult: []byte(`{"key": "value"}`),
293-
expectedError: nil,
294-
},
295-
{
296-
name: "Error",
297-
token: "test_token",
298-
mockResponse: nil,
299-
mockError: fmt.Errorf("some error"),
300-
expectedResult: nil,
301-
expectedError: fmt.Errorf("some error"),
302-
},
303-
}
304243

305-
for _, tt := range testCases {
306-
t.Run(tt.name, func(t *testing.T) {
307-
mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) {
308-
return tt.mockResponse, tt.mockError
244+
isExpired, err := tokenExpired(token)
245+
if err != nil && !tt.expectedErr {
246+
t.Fatalf("failed on valid input: %v", err)
309247
}
310-
311-
c := &KeyFlow{
312-
config: &KeyFlowConfig{ClientRetry: NewRetryConfig()},
313-
doer: mockDo,
248+
if err != nil {
249+
return
314250
}
315-
316-
result, err := c.getJwksJSON(tt.token)
317-
318-
if tt.expectedError != nil {
319-
if err == nil {
320-
t.Errorf("Expected error %v but no error was returned", tt.expectedError)
321-
} else if tt.expectedError.Error() != err.Error() {
322-
t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err)
323-
}
324-
} else {
325-
if err != nil {
326-
t.Errorf("Expected no error but error was returned: %v", err)
327-
}
328-
if !cmp.Equal(tt.expectedResult, result) {
329-
t.Errorf("The returned result is wrong. Expected %s, got %s", string(tt.expectedResult), string(result))
330-
}
251+
if isExpired != tt.expectedIsExpired {
252+
t.Fatalf("expected isValid to be %t, got %t", tt.expectedIsExpired, isExpired)
331253
}
332254
})
333255
}

core/config/config.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ type Configuration struct {
7878
PrivateKeyPath string `json:"privateKeyPath,omitempty"`
7979
CredentialsFilePath string `json:"credentialsFilePath,omitempty"`
8080
TokenCustomUrl string `json:"tokenCustomUrl,omitempty"`
81-
JWKSCustomUrl string `json:"jwksCustomUrl,omitempty"`
8281
Region string `json:"region,omitempty"`
8382
CustomAuth http.RoundTripper
8483
Servers ServerConfigurations
@@ -156,10 +155,9 @@ func WithTokenEndpoint(url string) ConfigurationOption {
156155
}
157156
}
158157

159-
// WithJWKSEndpoint returns a ConfigurationOption that overrides the default url to be used to get the jwks when using the key flow
160-
func WithJWKSEndpoint(url string) ConfigurationOption {
158+
// Deprecated: validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect, and will be removed in a later update
159+
func WithJWKSEndpoint(_ string) ConfigurationOption {
161160
return func(config *Configuration) error {
162-
config.JWKSCustomUrl = url
163161
return nil
164162
}
165163
}

0 commit comments

Comments
 (0)