From 71cf8a4b66cb68fea7c48405174becb5664e94db Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Thu, 26 Feb 2026 12:55:58 +0000 Subject: [PATCH] Move cloud filtering from individual strategies to DefaultCredentials Previously each Azure/GCP strategy would return nil on a non-matching host cloud (e.g. AzureCliCredentials returned nil on a GCP host). This meant that explicitly setting auth_type="azure-cli" on a GCP host was silently ignored. Now DefaultCredentials owns the cloud filtering: a cloudRequirements map in credentialsChain associates each cloud-specific strategy with its required cloud. In auto-detect mode, mismatched strategies are skipped with a debug log. When auth_type is explicitly set, the map is not consulted and the named strategy is always attempted regardless of the detected host cloud. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hector Castejon Diaz --- NEXT_CHANGELOG.md | 1 + config/auth_azure_cli.go | 3 -- config/auth_azure_cli_test.go | 9 +---- config/auth_azure_client_secret.go | 3 -- config/auth_azure_github_oidc.go | 3 +- config/auth_azure_msi.go | 2 +- config/auth_default.go | 44 +++++++++++++++++++-- config/auth_default_test.go | 57 +++++++++++++++++++++++++++ config/auth_gcp_google_credentials.go | 2 +- config/auth_gcp_google_id.go | 2 +- 10 files changed, 104 insertions(+), 22 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 85d401ba9..0ecc91d18 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -12,5 +12,6 @@ ### Internal Changes * Implement dynamic auth token stale period based on initial token lifetime. Increased up to 20 mins for standard OAuth with proportionally shorter periods for short-lived tokens. +* Move cloud-based credential filtering from individual strategies into `DefaultCredentials`. Azure strategies are skipped on GCP/AWS hosts in auto-detect mode; GCP strategies are skipped on Azure/AWS hosts. When `auth_type` is explicitly set (e.g. `azure-cli`), cloud filtering is bypassed so the named strategy is always attempted regardless of host cloud. ### API Changes diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index 2da57fccc..b81c3276a 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -56,9 +56,6 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner } func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if !cfg.IsAzure() { - return nil, nil - } // Set the azure tenant ID from host if available err := cfg.loadAzureTenantId(ctx) if err != nil { diff --git a/config/auth_azure_cli_test.go b/config/auth_azure_cli_test.go index af02e97c4..f755887fe 100644 --- a/config/auth_azure_cli_test.go +++ b/config/auth_azure_cli_test.go @@ -51,11 +51,13 @@ var azDummy = &Config{ Host: "https://adb-xyz.c.azuredatabricks.net/", azureTenantIdFetchClient: makeClient(redirectResponse), } + var azDummyWithResourceId = &Config{ Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123", azureTenantIdFetchClient: makeClient(redirectResponse), } + var azDummyWitInvalidResourceId = &Config{ Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "invalidResourceId", @@ -78,13 +80,6 @@ func testdataPath() string { return strings.Join(paths, ":") } -func TestAzureCliCredentials_SkipAws(t *testing.T) { - aa := AzureCliCredentials{} - x, err := aa.Configure(context.Background(), &Config{Host: "https://xyz.cloud.databricks.com/"}) - assert.Nil(t, x) - assert.NoError(t, err) -} - func TestAzureCliCredentials_NotInstalled(t *testing.T) { env.CleanupEnvironment(t) os.Setenv("PATH", "whatever") diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index 88dc6de3c..1b1a54e9d 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -39,9 +39,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config if cfg.AzureClientID == "" || cfg.AzureClientSecret == "" || cfg.AzureTenantID == "" { return nil, nil } - if !cfg.IsAzure() { - return nil, nil - } err := cfg.loadAzureTenantId(ctx) if err != nil { return nil, fmt.Errorf("load tenant id: %w", err) diff --git a/config/auth_azure_github_oidc.go b/config/auth_azure_github_oidc.go index d05e76a1a..de30640f1 100644 --- a/config/auth_azure_github_oidc.go +++ b/config/auth_azure_github_oidc.go @@ -24,8 +24,7 @@ func (c AzureGithubOIDCCredentials) Name() string { // Configure implements [CredentialsStrategy.Configure]. func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - // Sanity check that the config is configured for Azure Databricks. - if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" { + if cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" { return nil, nil } supplier := oidc.NewGithubIDTokenSource( diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index eb577217f..1dc9a498d 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -32,7 +32,7 @@ func (c AzureMsiCredentials) Name() string { } func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) { + if !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) { return nil, nil } env := cfg.Environment() diff --git a/config/auth_default.go b/config/auth_default.go index fc99d3c41..258753b4d 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" ) @@ -22,13 +23,25 @@ var ErrCannotConfigureDefault = fmt.Errorf("cannot configure default credentials // INTERNAL: This function is not part of the public API and is subject to // change. Users are encouraged to use an explicit credentials strategy rather // than relying on a custom credentials chain. -func NewCredentialsChain(strategies ...CredentialsStrategy) CredentialsStrategy { +func NewCredentialsChain(strategies ...CredentialsStrategy) *credentialsChain { return &credentialsChain{strategies: strategies} } type credentialsChain struct { strategies []CredentialsStrategy - name string + // cloudRequirements maps a strategy name to the cloud it requires. When + // set, the auto-detect loop skips strategies whose required cloud does not + // match the configured host. The map is not consulted when AuthType is + // explicitly set — in that case the named strategy is always attempted. + cloudRequirements map[string]environment.Cloud + name string +} + +// WithCloudRequirements sets the cloud requirements for the chain and returns +// the chain for method chaining. +func (c *credentialsChain) WithCloudRequirements(m map[string]environment.Cloud) *credentialsChain { + c.cloudRequirements = m + return c } func (c *credentialsChain) Name() string { @@ -42,7 +55,9 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti } // If an auth type is specified, try to configure the credentials for that - // specific auth type. If an error is encountered, return it. + // specific auth type. Cloud filtering is bypassed entirely so that users + // can explicitly request any strategy regardless of detected cloud (e.g. + // "azure-cli" on a GCP host). if cfg.AuthType != "" { for _, s := range c.strategies { if s.Name() == cfg.AuthType { @@ -58,6 +73,16 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti // succeeds, returns the credentials provider. If a strategy fails, swallow // the error and try the next strategy. for _, s := range c.strategies { + // In auto-detect mode, skip cloud-specific strategies that don't match + // the detected cloud. This prevents Azure strategies from being + // attempted silently on GCP hosts and vice-versa. + if requiredCloud, ok := c.cloudRequirements[s.Name()]; ok { + if cfg.Environment().Cloud != requiredCloud { + logger.Debugf(ctx, "Skipping %q: not configured for %s", s.Name(), requiredCloud) + continue + } + } + logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) cp, err := s.Configure(ctx, cfg) if err != nil || cp == nil { @@ -119,7 +144,18 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden // Google strategies. GoogleCredentials{}, GoogleDefaultCredentials{}, - ) + ).WithCloudRequirements(map[string]environment.Cloud{ + // cloudRequirements declares the cloud each strategy requires. + // DefaultCredentials uses this to skip cloud-specific strategies in + // auto-detect mode when the host cloud does not match. Cloud filtering + // is bypassed when AuthType is explicitly set. + "github-oidc-azure": environment.CloudAzure, + "azure-msi": environment.CloudAzure, + "azure-client-secret": environment.CloudAzure, + "azure-cli": environment.CloudAzure, + "google-credentials": environment.CloudGCP, + "google-id": environment.CloudGCP, + }) cp, err := chain.Configure(ctx, cfg) if err != nil { return nil, err diff --git a/config/auth_default_test.go b/config/auth_default_test.go index 12cbdcc11..06fbc77a5 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -5,8 +5,65 @@ import ( "net/http" "strings" "testing" + + "github.com/databricks/databricks-sdk-go/common/environment" + "github.com/databricks/databricks-sdk-go/config/credentials" ) +// recordingStrategy is a test helper that records whether Configure was called. +type recordingStrategy struct { + name string + called bool +} + +func (r *recordingStrategy) Name() string { return r.name } +func (r *recordingStrategy) Configure(_ context.Context, _ *Config) (credentials.CredentialsProvider, error) { + r.called = true + return nil, nil +} + +// TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch verifies that the +// chain skips a cloud-specific strategy in auto-detect mode when the detected +// cloud does not match the strategy's required cloud. +func TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch(t *testing.T) { + azureStrategy := &recordingStrategy{name: "azure-cli"} + chain := &credentialsChain{ + strategies: []CredentialsStrategy{azureStrategy}, + cloudRequirements: map[string]environment.Cloud{ + "azure-cli": environment.CloudAzure, + }, + } + + // GCP host: azure-cli must be skipped in auto-detect mode. + cfg := &Config{Host: "https://xyz.gcp.databricks.com/", resolved: true} + chain.Configure(context.Background(), cfg) //nolint:errcheck + + if azureStrategy.called { + t.Error("azure-cli strategy was called on GCP host, want it to be skipped in auto-detect mode") + } +} + +// TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType verifies that +// the cloud filter is bypassed when AuthType is explicitly set, so that a user +// can request "azure-cli" even on a GCP host. +func TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType(t *testing.T) { + azureStrategy := &recordingStrategy{name: "azure-cli"} + chain := &credentialsChain{ + strategies: []CredentialsStrategy{azureStrategy}, + cloudRequirements: map[string]environment.Cloud{ + "azure-cli": environment.CloudAzure, + }, + } + + // GCP host but auth_type is explicitly set: cloud filter must be bypassed. + cfg := &Config{Host: "https://xyz.gcp.databricks.com/", AuthType: "azure-cli", resolved: true} + chain.Configure(context.Background(), cfg) //nolint:errcheck + + if !azureStrategy.called { + t.Error("azure-cli strategy was not called despite explicit auth_type on GCP host, want bypass of cloud filter") + } +} + func TestDefaultCredentialStrategy(t *testing.T) { original := DefaultCredentialStrategyProvider t.Cleanup(func() { DefaultCredentialStrategyProvider = original }) diff --git a/config/auth_gcp_google_credentials.go b/config/auth_gcp_google_credentials.go index 3477ba427..2453079ab 100644 --- a/config/auth_gcp_google_credentials.go +++ b/config/auth_gcp_google_credentials.go @@ -21,7 +21,7 @@ func (c GoogleCredentials) Name() string { } func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if cfg.GoogleCredentials == "" || !cfg.IsGcp() { + if cfg.GoogleCredentials == "" { return nil, nil } json, err := readCredentials(cfg.GoogleCredentials) diff --git a/config/auth_gcp_google_id.go b/config/auth_gcp_google_id.go index 17d1112d1..8ef76d720 100644 --- a/config/auth_gcp_google_id.go +++ b/config/auth_gcp_google_id.go @@ -21,7 +21,7 @@ func (c GoogleDefaultCredentials) Name() string { } func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() { + if cfg.GoogleServiceAccount == "" { return nil, nil } inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, c.opts...)