Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() {
return nil, nil
}
// Set the azure tenant ID from host if available
err := cfg.loadAzureTenantId(ctx)
if err != nil {
Expand Down
9 changes: 7 additions & 2 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,11 @@ var azDummy = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

var azDummyWithResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

var azDummyWitInvalidResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "invalidResourceId",
Expand All @@ -80,6 +78,13 @@ func testdataPath() string {
return strings.Join(paths, ":")
}

func TestAzureCliCredentials_SkipAws(t *testing.T) {
aa := AzureCliCredentials{}
x, err := aa.Configure(context.Background(), &Config{Host: "https://xyz.cloud.databricks.com/"})
assert.Nil(t, x)
assert.NoError(t, err)
}

func TestAzureCliCredentials_NotInstalled(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", "whatever")
Expand Down
3 changes: 3 additions & 0 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if cfg.AzureClientID == "" || cfg.AzureClientSecret == "" || cfg.AzureTenantID == "" {
return nil, nil
}
if !cfg.IsAzure() {
return nil, nil
}
err := cfg.loadAzureTenantId(ctx)
if err != nil {
return nil, fmt.Errorf("load tenant id: %w", err)
Expand Down
3 changes: 2 additions & 1 deletion config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func (c AzureGithubOIDCCredentials) Name() string {

// Configure implements [CredentialsStrategy.Configure].
func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
// Sanity check that the config is configured for Azure Databricks.
if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
return nil, nil
}
supplier := oidc.NewGithubIDTokenSource(
Expand Down
2 changes: 1 addition & 1 deletion config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (c AzureMsiCredentials) Name() string {
}

func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) {
return nil, nil
}
env := cfg.Environment()
Expand Down
44 changes: 4 additions & 40 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)
Expand All @@ -23,25 +22,13 @@ var ErrCannotConfigureDefault = fmt.Errorf("cannot configure default credentials
// INTERNAL: This function is not part of the public API and is subject to
// change. Users are encouraged to use an explicit credentials strategy rather
// than relying on a custom credentials chain.
func NewCredentialsChain(strategies ...CredentialsStrategy) *credentialsChain {
func NewCredentialsChain(strategies ...CredentialsStrategy) CredentialsStrategy {
return &credentialsChain{strategies: strategies}
}

type credentialsChain struct {
strategies []CredentialsStrategy
// cloudRequirements maps a strategy name to the cloud it requires. When
// set, the auto-detect loop skips strategies whose required cloud does not
// match the configured host. The map is not consulted when AuthType is
// explicitly set — in that case the named strategy is always attempted.
cloudRequirements map[string]environment.Cloud
name string
}

// WithCloudRequirements sets the cloud requirements for the chain and returns
// the chain for method chaining.
func (c *credentialsChain) WithCloudRequirements(m map[string]environment.Cloud) *credentialsChain {
c.cloudRequirements = m
return c
name string
}

func (c *credentialsChain) Name() string {
Expand All @@ -55,9 +42,7 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti
}

// If an auth type is specified, try to configure the credentials for that
// specific auth type. Cloud filtering is bypassed entirely so that users
// can explicitly request any strategy regardless of detected cloud (e.g.
// "azure-cli" on a GCP host).
// specific auth type. If an error is encountered, return it.
if cfg.AuthType != "" {
for _, s := range c.strategies {
if s.Name() == cfg.AuthType {
Expand All @@ -73,16 +58,6 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti
// succeeds, returns the credentials provider. If a strategy fails, swallow
// the error and try the next strategy.
for _, s := range c.strategies {
// In auto-detect mode, skip cloud-specific strategies that don't match
// the detected cloud. This prevents Azure strategies from being
// attempted silently on GCP hosts and vice-versa.
if requiredCloud, ok := c.cloudRequirements[s.Name()]; ok {
if cfg.Environment().Cloud != requiredCloud {
logger.Debugf(ctx, "Skipping %q: not configured for %s", s.Name(), requiredCloud)
continue
}
}

logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name())
cp, err := s.Configure(ctx, cfg)
if err != nil || cp == nil {
Expand Down Expand Up @@ -144,18 +119,7 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden
// Google strategies.
GoogleCredentials{},
GoogleDefaultCredentials{},
).WithCloudRequirements(map[string]environment.Cloud{
// cloudRequirements declares the cloud each strategy requires.
// DefaultCredentials uses this to skip cloud-specific strategies in
// auto-detect mode when the host cloud does not match. Cloud filtering
// is bypassed when AuthType is explicitly set.
"github-oidc-azure": environment.CloudAzure,
"azure-msi": environment.CloudAzure,
"azure-client-secret": environment.CloudAzure,
"azure-cli": environment.CloudAzure,
"google-credentials": environment.CloudGCP,
"google-id": environment.CloudGCP,
})
)
cp, err := chain.Configure(ctx, cfg)
if err != nil {
return nil, err
Expand Down
57 changes: 0 additions & 57 deletions config/auth_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,8 @@ import (
"net/http"
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/config/credentials"
)

// recordingStrategy is a test helper that records whether Configure was called.
type recordingStrategy struct {
name string
called bool
}

func (r *recordingStrategy) Name() string { return r.name }
func (r *recordingStrategy) Configure(_ context.Context, _ *Config) (credentials.CredentialsProvider, error) {
r.called = true
return nil, nil
}

// TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch verifies that the
// chain skips a cloud-specific strategy in auto-detect mode when the detected
// cloud does not match the strategy's required cloud.
func TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch(t *testing.T) {
azureStrategy := &recordingStrategy{name: "azure-cli"}
chain := &credentialsChain{
strategies: []CredentialsStrategy{azureStrategy},
cloudRequirements: map[string]environment.Cloud{
"azure-cli": environment.CloudAzure,
},
}

// GCP host: azure-cli must be skipped in auto-detect mode.
cfg := &Config{Host: "https://xyz.gcp.databricks.com/", resolved: true}
chain.Configure(context.Background(), cfg) //nolint:errcheck

if azureStrategy.called {
t.Error("azure-cli strategy was called on GCP host, want it to be skipped in auto-detect mode")
}
}

// TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType verifies that
// the cloud filter is bypassed when AuthType is explicitly set, so that a user
// can request "azure-cli" even on a GCP host.
func TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType(t *testing.T) {
azureStrategy := &recordingStrategy{name: "azure-cli"}
chain := &credentialsChain{
strategies: []CredentialsStrategy{azureStrategy},
cloudRequirements: map[string]environment.Cloud{
"azure-cli": environment.CloudAzure,
},
}

// GCP host but auth_type is explicitly set: cloud filter must be bypassed.
cfg := &Config{Host: "https://xyz.gcp.databricks.com/", AuthType: "azure-cli", resolved: true}
chain.Configure(context.Background(), cfg) //nolint:errcheck

if !azureStrategy.called {
t.Error("azure-cli strategy was not called despite explicit auth_type on GCP host, want bypass of cloud filter")
}
}

func TestDefaultCredentialStrategy(t *testing.T) {
original := DefaultCredentialStrategyProvider
t.Cleanup(func() { DefaultCredentialStrategyProvider = original })
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c GoogleCredentials) Name() string {
}

func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleCredentials == "" {
if cfg.GoogleCredentials == "" || !cfg.IsGcp() {
return nil, nil
}
json, err := readCredentials(cfg.GoogleCredentials)
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c GoogleDefaultCredentials) Name() string {
}

func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleServiceAccount == "" {
if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() {
return nil, nil
}
inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, c.opts...)
Expand Down
Loading