diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 274315be8..b001d35ae 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -8,6 +8,8 @@ ### Bug Fixes + * Restore `NewCredentialsChain` return type to `CredentialsStrategy` interface ([#1516](https://github.com/databricks/databricks-sdk-go/pull/1516)). + ### Documentation ### Internal Changes diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index b81c3276a..910b54df9 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -11,6 +11,7 @@ import ( "golang.org/x/oauth2" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" ) @@ -28,6 +29,14 @@ func (c AzureCliCredentials) Name() string { return "azure-cli" } +// Validate implements [ValidatingStrategy.Validate]. +func (c AzureCliCredentials) Validate(_ context.Context, cfg *Config) error { + if cfg.Environment().Cloud != environment.CloudAzure { + return fmt.Errorf("%w: requires Azure, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + // implementing azureHostResolver for ensureWorkspaceUrl to work func (c AzureCliCredentials) tokenSourceFor( ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource { diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index 1b1a54e9d..3c8e5f6d8 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -8,6 +8,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" ) @@ -19,6 +20,23 @@ func (c AzureClientSecretCredentials) Name() string { return "azure-client-secret" } +// Validate implements [ValidatingStrategy.Validate]. +func (c AzureClientSecretCredentials) Validate(_ context.Context, cfg *Config) error { + if cfg.AzureClientID == "" { + return fmt.Errorf("azure_client_id is required") + } + if cfg.AzureClientSecret == "" { + return fmt.Errorf("azure_client_secret is required") + } + if cfg.AzureTenantID == "" { + return fmt.Errorf("azure_tenant_id is required") + } + if cfg.Environment().Cloud != environment.CloudAzure { + return fmt.Errorf("%w: requires Azure, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + func (c AzureClientSecretCredentials) tokenSourceFor( ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource { return (&clientcredentials.Config{ diff --git a/config/auth_azure_github_oidc.go b/config/auth_azure_github_oidc.go index de30640f1..f3c14855c 100644 --- a/config/auth_azure_github_oidc.go +++ b/config/auth_azure_github_oidc.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" "github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc" @@ -22,6 +23,29 @@ func (c AzureGithubOIDCCredentials) Name() string { return "github-oidc-azure" } +// Validate implements [ValidatingStrategy.Validate]. +func (c AzureGithubOIDCCredentials) Validate(_ context.Context, cfg *Config) error { + if cfg.AzureClientID == "" { + return fmt.Errorf("azure_client_id is required") + } + if cfg.Host == "" { + return fmt.Errorf("host is required") + } + if cfg.AzureTenantID == "" { + return fmt.Errorf("azure_tenant_id is required") + } + if cfg.ActionsIDTokenRequestURL == "" { + return fmt.Errorf("ACTIONS_ID_TOKEN_REQUEST_URL is required") + } + if cfg.ActionsIDTokenRequestToken == "" { + return fmt.Errorf("ACTIONS_ID_TOKEN_REQUEST_TOKEN is required") + } + if cfg.Environment().Cloud != environment.CloudAzure { + return fmt.Errorf("%w: requires Azure, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + // Configure implements [CredentialsStrategy.Configure]. func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { if cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" { diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index 1dc9a498d..fee87a72c 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" @@ -31,6 +32,20 @@ func (c AzureMsiCredentials) Name() string { return "azure-msi" } +// Validate implements [ValidatingStrategy.Validate]. +func (c AzureMsiCredentials) Validate(_ context.Context, cfg *Config) error { + if !cfg.AzureUseMSI { + return fmt.Errorf("azure_use_msi is not enabled") + } + if cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig { + return fmt.Errorf("azure_workspace_resource_id is required for workspace authentication") + } + if cfg.Environment().Cloud != environment.CloudAzure { + return fmt.Errorf("%w: requires Azure, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { if !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) { return nil, nil diff --git a/config/auth_default.go b/config/auth_default.go index 258753b4d..6d141ee3f 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -2,10 +2,10 @@ package config import ( "context" + "errors" "fmt" "net/http" - "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" ) @@ -23,25 +23,13 @@ var ErrCannotConfigureDefault = fmt.Errorf("cannot configure default credentials // INTERNAL: This function is not part of the public API and is subject to // change. Users are encouraged to use an explicit credentials strategy rather // than relying on a custom credentials chain. -func NewCredentialsChain(strategies ...CredentialsStrategy) *credentialsChain { +func NewCredentialsChain(strategies ...CredentialsStrategy) CredentialsStrategy { return &credentialsChain{strategies: strategies} } type credentialsChain struct { strategies []CredentialsStrategy - // cloudRequirements maps a strategy name to the cloud it requires. When - // set, the auto-detect loop skips strategies whose required cloud does not - // match the configured host. The map is not consulted when AuthType is - // explicitly set — in that case the named strategy is always attempted. - cloudRequirements map[string]environment.Cloud - name string -} - -// WithCloudRequirements sets the cloud requirements for the chain and returns -// the chain for method chaining. -func (c *credentialsChain) WithCloudRequirements(m map[string]environment.Cloud) *credentialsChain { - c.cloudRequirements = m - return c + name string } func (c *credentialsChain) Name() string { @@ -55,12 +43,18 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti } // If an auth type is specified, try to configure the credentials for that - // specific auth type. Cloud filtering is bypassed entirely so that users - // can explicitly request any strategy regardless of detected cloud (e.g. - // "azure-cli" on a GCP host). + // specific auth type. ErrInvalidCloud is ignored so users can explicitly + // request a cloud-specific strategy (e.g. "azure-cli") on any host. Other + // validation errors are propagated so misconfigured strategies fail fast. if cfg.AuthType != "" { for _, s := range c.strategies { if s.Name() == cfg.AuthType { + if vs, ok := s.(ValidatingStrategy); ok { + // ErrInvalidCloud is ignored so users can explicitly request a cloud-specific strategy (e.g. "azure-cli") on any host. + if err := vs.Validate(ctx, cfg); err != nil && !errors.Is(err, ErrInvalidCloud) { + return nil, err + } + } logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name()) c.name = s.Name() return s.Configure(ctx, cfg) @@ -73,12 +67,12 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti // succeeds, returns the credentials provider. If a strategy fails, swallow // the error and try the next strategy. for _, s := range c.strategies { - // In auto-detect mode, skip cloud-specific strategies that don't match - // the detected cloud. This prevents Azure strategies from being - // attempted silently on GCP hosts and vice-versa. - if requiredCloud, ok := c.cloudRequirements[s.Name()]; ok { - if cfg.Environment().Cloud != requiredCloud { - logger.Debugf(ctx, "Skipping %q: not configured for %s", s.Name(), requiredCloud) + // In auto-detect mode, consult ValidatingStrategy to skip strategies + // that are not applicable (e.g. cloud mismatch). This prevents Azure + // strategies from being attempted silently on GCP hosts and vice-versa. + if vs, ok := s.(ValidatingStrategy); ok { + if err := vs.Validate(ctx, cfg); err != nil { + logger.Debugf(ctx, "Skipping %q: %v", s.Name(), err) continue } } @@ -144,18 +138,7 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden // Google strategies. GoogleCredentials{}, GoogleDefaultCredentials{}, - ).WithCloudRequirements(map[string]environment.Cloud{ - // cloudRequirements declares the cloud each strategy requires. - // DefaultCredentials uses this to skip cloud-specific strategies in - // auto-detect mode when the host cloud does not match. Cloud filtering - // is bypassed when AuthType is explicitly set. - "github-oidc-azure": environment.CloudAzure, - "azure-msi": environment.CloudAzure, - "azure-client-secret": environment.CloudAzure, - "azure-cli": environment.CloudAzure, - "google-credentials": environment.CloudGCP, - "google-id": environment.CloudGCP, - }) + ) cp, err := chain.Configure(ctx, cfg) if err != nil { return nil, err diff --git a/config/auth_default_test.go b/config/auth_default_test.go index 06fbc77a5..3701ac9f6 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -2,6 +2,7 @@ package config import ( "context" + "fmt" "net/http" "strings" "testing" @@ -10,15 +11,23 @@ import ( "github.com/databricks/databricks-sdk-go/config/credentials" ) -// recordingStrategy is a test helper that records whether Configure was called. -type recordingStrategy struct { +// validatingStrategy is a test helper that implements [CredentialsStrategy] +// and [ValidatingStrategy], recording whether Configure was called. +type validatingStrategy struct { name string called bool + cloud environment.Cloud } -func (r *recordingStrategy) Name() string { return r.name } -func (r *recordingStrategy) Configure(_ context.Context, _ *Config) (credentials.CredentialsProvider, error) { - r.called = true +func (s *validatingStrategy) Name() string { return s.name } +func (s *validatingStrategy) Validate(_ context.Context, cfg *Config) error { + if cfg.Environment().Cloud != s.cloud { + return fmt.Errorf("%w: requires %s, got %s", ErrInvalidCloud, s.cloud, cfg.Environment().Cloud) + } + return nil +} +func (s *validatingStrategy) Configure(_ context.Context, _ *Config) (credentials.CredentialsProvider, error) { + s.called = true return nil, nil } @@ -26,13 +35,8 @@ func (r *recordingStrategy) Configure(_ context.Context, _ *Config) (credentials // chain skips a cloud-specific strategy in auto-detect mode when the detected // cloud does not match the strategy's required cloud. func TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch(t *testing.T) { - azureStrategy := &recordingStrategy{name: "azure-cli"} - chain := &credentialsChain{ - strategies: []CredentialsStrategy{azureStrategy}, - cloudRequirements: map[string]environment.Cloud{ - "azure-cli": environment.CloudAzure, - }, - } + azureStrategy := &validatingStrategy{name: "azure-cli", cloud: environment.CloudAzure} + chain := &credentialsChain{strategies: []CredentialsStrategy{azureStrategy}} // GCP host: azure-cli must be skipped in auto-detect mode. cfg := &Config{Host: "https://xyz.gcp.databricks.com/", resolved: true} @@ -47,13 +51,8 @@ func TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch(t *testing.T) { // the cloud filter is bypassed when AuthType is explicitly set, so that a user // can request "azure-cli" even on a GCP host. func TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType(t *testing.T) { - azureStrategy := &recordingStrategy{name: "azure-cli"} - chain := &credentialsChain{ - strategies: []CredentialsStrategy{azureStrategy}, - cloudRequirements: map[string]environment.Cloud{ - "azure-cli": environment.CloudAzure, - }, - } + azureStrategy := &validatingStrategy{name: "azure-cli", cloud: environment.CloudAzure} + chain := &credentialsChain{strategies: []CredentialsStrategy{azureStrategy}} // GCP host but auth_type is explicitly set: cloud filter must be bypassed. cfg := &Config{Host: "https://xyz.gcp.databricks.com/", AuthType: "azure-cli", resolved: true} diff --git a/config/auth_gcp_google_credentials.go b/config/auth_gcp_google_credentials.go index 2453079ab..b26d7e844 100644 --- a/config/auth_gcp_google_credentials.go +++ b/config/auth_gcp_google_credentials.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2/google" @@ -20,6 +21,17 @@ func (c GoogleCredentials) Name() string { return "google-credentials" } +// Validate implements [ValidatingStrategy.Validate]. +func (c GoogleCredentials) Validate(_ context.Context, cfg *Config) error { + if cfg.GoogleCredentials == "" { + return fmt.Errorf("google_credentials is not set") + } + if cfg.Environment().Cloud != environment.CloudGCP { + return fmt.Errorf("%w: requires GCP, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { if cfg.GoogleCredentials == "" { return nil, nil diff --git a/config/auth_gcp_google_id.go b/config/auth_gcp_google_id.go index 8ef76d720..bc7e70ed1 100644 --- a/config/auth_gcp_google_id.go +++ b/config/auth_gcp_google_id.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" @@ -20,6 +21,17 @@ func (c GoogleDefaultCredentials) Name() string { return "google-id" } +// Validate implements [ValidatingStrategy.Validate]. +func (c GoogleDefaultCredentials) Validate(_ context.Context, cfg *Config) error { + if cfg.GoogleServiceAccount == "" { + return fmt.Errorf("google_service_account is not set") + } + if cfg.Environment().Cloud != environment.CloudGCP { + return fmt.Errorf("%w: requires GCP, got %s", ErrInvalidCloud, cfg.Environment().Cloud) + } + return nil +} + func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { if cfg.GoogleServiceAccount == "" { return nil, nil diff --git a/config/config.go b/config/config.go index ac07d4d46..a77e8ed32 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,26 @@ type CredentialsStrategy interface { Configure(context.Context, *Config) (credentials.CredentialsProvider, error) } +// ErrInvalidCloud is returned by [ValidatingStrategy.Validate] when the +// configured host's cloud does not match the cloud required by the strategy. +// In auto-detect mode, strategies returning this error are skipped. When +// [Config.AuthType] is explicitly set, this error is ignored and the strategy +// is still attempted, allowing users to force a cloud-specific auth method on +// any host. +var ErrInvalidCloud = errors.New("cloud not supported by this strategy") + +// ValidatingStrategy is an optional interface that a [CredentialsStrategy] can +// implement to declare upfront whether it is applicable for the current +// configuration. The credentials chain consults Validate before Configure: +// - In auto-detect mode: any error from Validate causes the strategy to be +// skipped. +// - When [Config.AuthType] is explicitly set: [ErrInvalidCloud] is ignored so +// that users can force a cloud-specific strategy on a different-cloud host. +// Any other validation error is propagated. +type ValidatingStrategy interface { + Validate(context.Context, *Config) error +} + type Loader interface { // Name is human-addressable representation of this config resolver Name() string