Skip to content

Commit

Permalink
Override OIDC provider discovered claims (#3267)
Browse files Browse the repository at this point in the history
Signed-off-by: m.nabokikh <[email protected]>
  • Loading branch information
nabokihms authored Jan 10, 2024
1 parent 231a97d commit 665a5b6
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 4 deletions.
70 changes: 66 additions & 4 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type Config struct {
ClientSecret string `json:"clientSecret"`
RedirectURI string `json:"redirectURI"`

// The section to override options discovered automatically from
// the providers' discovery URL (.well-known/openid-configuration).
ProviderDiscoveryOverrides ProviderDiscoveryOverrides `json:"providerDiscoveryOverrides"`

// Causes client_secret to be passed as POST parameters instead of basic
// auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some
// providers require it.
Expand Down Expand Up @@ -96,6 +100,61 @@ type Config struct {
} `json:"claimModifications"`
}

type ProviderDiscoveryOverrides struct {
// TokenURL provides a way to user overwrite the Token URL
// from the .well-known/openid-configuration token_endpoint
TokenURL string `json:"tokenURL"`
// AuthURL provides a way to user overwrite the Auth URL
// from the .well-known/openid-configuration authorization_endpoint
AuthURL string `json:"authURL"`
}

func (o *ProviderDiscoveryOverrides) Empty() bool {
return o.TokenURL == "" && o.AuthURL == ""
}

func getProvider(ctx context.Context, issuer string, overrides ProviderDiscoveryOverrides) (*oidc.Provider, error) {
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
return nil, fmt.Errorf("failed to get provider: %v", err)
}

if overrides.Empty() {
return provider, nil
}

v := &struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
DeviceAuthURL string `json:"device_authorization_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
}{}
if err := provider.Claims(v); err != nil {
return nil, fmt.Errorf("failed to extract provider discovery claims: %v", err)
}
config := oidc.ProviderConfig{
IssuerURL: v.Issuer,
AuthURL: v.AuthURL,
TokenURL: v.TokenURL,
DeviceAuthURL: v.DeviceAuthURL,
JWKSURL: v.JWKSURL,
UserInfoURL: v.UserInfoURL,
Algorithms: v.Algorithms,
}

if overrides.TokenURL != "" {
config.TokenURL = overrides.TokenURL
}
if overrides.AuthURL != "" {
config.AuthURL = overrides.AuthURL
}

return config.NewProvider(context.Background()), nil
}

// NewGroupFromClaims creates a new group from a list of claims and appends it to the list of existing groups.
type NewGroupFromClaims struct {
// List of claim to join together
Expand Down Expand Up @@ -152,13 +211,16 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
return nil, err
}

ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
bgctx, cancel := context.WithCancel(context.Background())
ctx := context.WithValue(bgctx, oauth2.HTTPClient, httpClient)

provider, err := oidc.NewProvider(ctx, c.Issuer)
provider, err := getProvider(ctx, c.Issuer, c.ProviderDiscoveryOverrides)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to get provider: %v", err)
return nil, err
}
if !c.ProviderDiscoveryOverrides.Empty() {
logger.Warnf("overrides for connector %q are set, this can be a vulnerability when not properly configured", id)
}

endpoint := provider.Endpoint()
Expand Down
51 changes: 51 additions & 0 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,57 @@ func TestTokenIdentity(t *testing.T) {
}
}

func TestProviderOverride(t *testing.T) {
testServer, err := setupServer(map[string]any{
"sub": "subvalue",
"name": "namevalue",
}, true)
if err != nil {
t.Fatal("failed to setup test server", err)
}

t.Run("No override", func(t *testing.T) {
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
})
if err != nil {
t.Fatal("failed to create new connector", err)
}

expAuth := fmt.Sprintf("%s/authorize", testServer.URL)
if conn.provider.Endpoint().AuthURL != expAuth {
t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth)
}

expToken := fmt.Sprintf("%s/token", testServer.URL)
if conn.provider.Endpoint().TokenURL != expToken {
t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken)
}
})

t.Run("Override", func(t *testing.T) {
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
ProviderDiscoveryOverrides: ProviderDiscoveryOverrides{TokenURL: "/test1", AuthURL: "/test2"},
})
if err != nil {
t.Fatal("failed to create new connector", err)
}

expAuth := "/test2"
if conn.provider.Endpoint().AuthURL != expAuth {
t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth)
}

expToken := "/test1"
if conn.provider.Endpoint().TokenURL != expToken {
t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken)
}
})
}

func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
Expand Down

0 comments on commit 665a5b6

Please sign in to comment.