Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- **Feature:** Add package `runtime`, which implements methods to be used when performing API requests
- **Feature:** Add method `WithCaptureHTTPResponse` to package `runtime`, which does the same as `config.WithCaptureHTTPResponse`. Method was moved to avoid confusion due to it not being a configuration option, and will be removed in a later release.
- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
- **Deprecation:** Marked method `config.WithJWKSEndpoint` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect

## Release (2024-02-07)

Expand Down
3 changes: 2 additions & 1 deletion core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## v0.9.0 (YYYY-MM-DD)

- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
- **Deprecation:** Marked method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due
- **Deprecation:** Marked method `config.WithJWKSEndpoint` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect

## v0.8.0 (2024-02-16)

Expand Down
7 changes: 0 additions & 7 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,12 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
cfg.TokenCustomUrl = tokenCustomUrl
}
}
if cfg.JWKSCustomUrl == "" {
jwksCustomUrl, jwksUrlSet := os.LookupEnv("STACKIT_JWKS_BASEURL")
if jwksUrlSet {
cfg.JWKSCustomUrl = jwksCustomUrl
}
}

keyCfg := clients.KeyFlowConfig{
ServiceAccountKey: serviceAccountKey,
PrivateKey: cfg.PrivateKey,
ClientRetry: cfg.RetryOptions,
TokenUrl: cfg.TokenCustomUrl,
JWKSUrl: cfg.JWKSCustomUrl,
}

client := &clients.KeyFlow{}
Expand Down
72 changes: 14 additions & 58 deletions core/clients/key_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
Expand All @@ -26,7 +24,6 @@ const (
ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH"
PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH"
tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive
jwksAPI = "https://service-account.api.stackit.cloud/.well-known/jwks.json"
defaultTokenType = "Bearer"
defaultScope = ""
)
Expand All @@ -48,7 +45,6 @@ type KeyFlowConfig struct {
PrivateKey string
ClientRetry *RetryConfig
TokenUrl string
JWKSUrl string
}

// TokenResponseBody is the API response
Expand Down Expand Up @@ -112,13 +108,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
c.config = cfg
c.doer = Do

// set defaults if no custom token and jwks url are provided
if c.config.TokenUrl == "" {
c.config.TokenUrl = tokenAPI
}
if c.config.JWKSUrl == "" {
c.config.JWKSUrl = jwksAPI
}
c.configureHTTPClient()
if c.config.ClientRetry == nil {
c.config.ClientRetry = NewRetryConfig()
Expand Down Expand Up @@ -181,11 +173,11 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {

// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
func (c *KeyFlow) GetAccessToken() (string, error) {
accessTokenIsValid, err := c.validateToken(c.token.AccessToken)
accessTokenExpired, err := tokenExpired(c.token.AccessToken)
if err != nil {
return "", fmt.Errorf("failed initial validation: %w", err)
}
if accessTokenIsValid {
if !accessTokenExpired {
return c.token.AccessToken, nil
}
if err := c.recreateAccessToken(); err != nil {
Expand Down Expand Up @@ -232,11 +224,11 @@ func (c *KeyFlow) validate() error {
// recreateAccessToken is used to create a new access token
// when the existing one isn't valid anymore
func (c *KeyFlow) recreateAccessToken() error {
refreshTokenIsValid, err := c.validateToken(c.token.RefreshToken)
refreshTokenExpired, err := tokenExpired(c.token.RefreshToken)
if err != nil {
return err
}
if refreshTokenIsValid {
if !refreshTokenExpired {
return c.createAccessTokenWithRefreshToken()
}
return c.createAccessToken()
Expand Down Expand Up @@ -327,54 +319,18 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error {
return json.Unmarshal(body, c.token)
}

// validateToken returns true if token is valid
func (c *KeyFlow) validateToken(token string) (bool, error) {
if token == "" {
return false, nil
}
if _, err := c.parseToken(token); err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
c.token = &TokenResponseBody{}
return false, nil
}
return false, fmt.Errorf("parse token: %w", err)
}
return true, nil
}

// parseToken parses and validates a JWT token
func (c *KeyFlow) parseToken(token string) (*jwt.Token, error) {
b, err := c.getJwksJSON(token)
if err != nil {
return nil, fmt.Errorf("get JWKS Json: %w", err)
}
var jwksBytes = json.RawMessage(b)
jwks, err := keyfunc.NewJSON(jwksBytes)
if err != nil {
return nil, fmt.Errorf("get JWKS function from JSON: %w", err)
}
return jwt.Parse(token, jwks.Keyfunc)
}

func (c *KeyFlow) getJwksJSON(token string) (jwks []byte, err error) {
req, err := http.NewRequest("GET", c.config.JWKSUrl, http.NoBody)
func tokenExpired(token string) (bool, error) {
// We can safely use ParseUnverified because we are not authenticating the user at this point.
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
if err != nil {
return nil, err
return false, fmt.Errorf("parse access token: %w", err)
}
req.Header.Add("Authorization", "Bearer "+token)
res, err := c.doer(&http.Client{}, req, c.config.ClientRetry)
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
if err != nil {
return nil, err
}
defer func() {
tempErr := res.Body.Close()
if tempErr != nil {
jwks = nil
err = fmt.Errorf("closing get jwks response: %w", tempErr)
}
}()
if res.StatusCode != 200 {
return nil, fmt.Errorf("getting jwks return error status %s", res.Status)
return false, fmt.Errorf("get expiration timestamp from access token: %w", err)
}
return io.ReadAll(res.Body)
expirationTimestamp := expirationTimestampNumeric.Time
now := time.Now()
return now.After(expirationTimestamp), nil
}
148 changes: 35 additions & 113 deletions core/clients/key_flow_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package clients

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
Expand All @@ -21,7 +20,6 @@ import (

var (
testSigningKey = []byte(`Test`)
testJwks = []byte(`{ "keys": [ { "kty":"oct", "kid":"test", "alg":"HS256" } ] }`)
)

func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse {
Expand Down Expand Up @@ -126,10 +124,6 @@ func TestKeyFlowInit(t *testing.T) {
}
}

type MyCustomClaims struct {
Foo string `json:"foo"`
}

func TestSetToken(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -207,127 +201,55 @@ func TestKeyClone(t *testing.T) {
}
}

func TestKeyFlowValidateToken(t *testing.T) {
// Generate a random private key
privateKey := make([]byte, 32)
if _, err := rand.Read(privateKey); err != nil {
t.Fatal(err)
}
func TestTokenExpired(t *testing.T) {
tests := []struct {
name string
token string
jwksMockResponse *http.Response
jwksMockError error
want bool
wantErr bool
desc string
tokenInvalid bool
tokenExpiresAt time.Time
expectedErr bool
expectedIsExpired bool
}{
{
name: "no token",
token: "",
want: false,
wantErr: false,
desc: "token valid",
tokenExpiresAt: time.Now().Add(time.Hour),
expectedErr: false,
expectedIsExpired: false,
},
{
name: "invalid token",
token: "bad token",
jwksMockResponse: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(testJwks)),
},
jwksMockError: nil,
want: false,
wantErr: true,
desc: "token expired",
tokenExpiresAt: time.Now().Add(-time.Hour),
expectedErr: false,
expectedIsExpired: true,
},
{
name: "get_jwks_fail",
token: "bad token",
jwksMockResponse: nil,
jwksMockError: fmt.Errorf("error"),
want: false,
wantErr: true,
desc: "token invalid",
tokenInvalid: true,
expectedErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) {
return tt.jwksMockResponse, tt.jwksMockError
}
c := &KeyFlow{
config: &KeyFlowConfig{
PrivateKey: string(privateKey),
JWKSUrl: jwksAPI,
},
doer: mockDo,
}

got, err := c.validateToken(tt.token)
if (err != nil) != tt.wantErr {
t.Errorf("KeyFlow.validateToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("KeyFlow.validateToken() = %v, want %v", got, tt.want)
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var err error
token := "foo"
if !tt.tokenInvalid {
token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create token: %v", err)
}
}
})
}
}

func TestGetJwksJSON(t *testing.T) {
testCases := []struct {
name string
token string
mockResponse *http.Response
mockError error
expectedResult []byte
expectedError error
}{
{
name: "Success",
token: "test_token",
mockResponse: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte(`{"key": "value"}`))),
},
mockError: nil,
expectedResult: []byte(`{"key": "value"}`),
expectedError: nil,
},
{
name: "Error",
token: "test_token",
mockResponse: nil,
mockError: fmt.Errorf("some error"),
expectedResult: nil,
expectedError: fmt.Errorf("some error"),
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) {
return tt.mockResponse, tt.mockError
isExpired, err := tokenExpired(token)
if err != nil && !tt.expectedErr {
t.Fatalf("failed on valid input: %v", err)
}

c := &KeyFlow{
config: &KeyFlowConfig{ClientRetry: NewRetryConfig()},
doer: mockDo,
if err != nil {
return
}

result, err := c.getJwksJSON(tt.token)

if tt.expectedError != nil {
if err == nil {
t.Errorf("Expected error %v but no error was returned", tt.expectedError)
} else if tt.expectedError.Error() != err.Error() {
t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err)
}
} else {
if err != nil {
t.Errorf("Expected no error but error was returned: %v", err)
}
if !cmp.Equal(tt.expectedResult, result) {
t.Errorf("The returned result is wrong. Expected %s, got %s", string(tt.expectedResult), string(result))
}
if isExpired != tt.expectedIsExpired {
t.Fatalf("expected isValid to be %t, got %t", tt.expectedIsExpired, isExpired)
}
})
}
Expand Down
6 changes: 2 additions & 4 deletions core/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ type Configuration struct {
PrivateKeyPath string `json:"privateKeyPath,omitempty"`
CredentialsFilePath string `json:"credentialsFilePath,omitempty"`
TokenCustomUrl string `json:"tokenCustomUrl,omitempty"`
JWKSCustomUrl string `json:"jwksCustomUrl,omitempty"`
Region string `json:"region,omitempty"`
CustomAuth http.RoundTripper
Servers ServerConfigurations
Expand Down Expand Up @@ -156,10 +155,9 @@ func WithTokenEndpoint(url string) ConfigurationOption {
}
}

// WithJWKSEndpoint returns a ConfigurationOption that overrides the default url to be used to get the jwks when using the key flow
func WithJWKSEndpoint(url string) ConfigurationOption {
// Deprecated: validation using JWKS was removed, for being redundant with token validation done in the APIs. This option has no effect, and will be removed in a later update
func WithJWKSEndpoint(_ string) ConfigurationOption {
return func(config *Configuration) error {
config.JWKSCustomUrl = url
return nil
}
}
Expand Down
Loading