Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@

### Internal Changes
* Implement dynamic auth token stale period based on initial token lifetime. Increased up to 20 mins for standard OAuth with proportionally shorter periods for short-lived tokens.
* Move cloud-based credential filtering from individual strategies into `DefaultCredentials`. Azure strategies are skipped on GCP/AWS hosts in auto-detect mode; GCP strategies are skipped on Azure/AWS hosts. When `auth_type` is explicitly set (e.g. `azure-cli`), cloud filtering is bypassed so the named strategy is always attempted regardless of host cloud.

### API Changes
3 changes: 0 additions & 3 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() {
return nil, nil
}
// Set the azure tenant ID from host if available
err := cfg.loadAzureTenantId(ctx)
if err != nil {
Expand Down
9 changes: 2 additions & 7 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ var azDummy = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

var azDummyWithResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

var azDummyWitInvalidResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "invalidResourceId",
Expand All @@ -78,13 +80,6 @@ func testdataPath() string {
return strings.Join(paths, ":")
}

func TestAzureCliCredentials_SkipAws(t *testing.T) {
aa := AzureCliCredentials{}
x, err := aa.Configure(context.Background(), &Config{Host: "https://xyz.cloud.databricks.com/"})
assert.Nil(t, x)
assert.NoError(t, err)
}

func TestAzureCliCredentials_NotInstalled(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", "whatever")
Expand Down
3 changes: 0 additions & 3 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if cfg.AzureClientID == "" || cfg.AzureClientSecret == "" || cfg.AzureTenantID == "" {
return nil, nil
}
if !cfg.IsAzure() {
return nil, nil
}
err := cfg.loadAzureTenantId(ctx)
if err != nil {
return nil, fmt.Errorf("load tenant id: %w", err)
Expand Down
3 changes: 1 addition & 2 deletions config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func (c AzureGithubOIDCCredentials) Name() string {

// Configure implements [CredentialsStrategy.Configure].
func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
// Sanity check that the config is configured for Azure Databricks.
if !cfg.IsAzure() || cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
if cfg.AzureClientID == "" || cfg.Host == "" || cfg.AzureTenantID == "" || cfg.ActionsIDTokenRequestURL == "" || cfg.ActionsIDTokenRequestToken == "" {
return nil, nil
}
supplier := oidc.NewGithubIDTokenSource(
Expand Down
2 changes: 1 addition & 1 deletion config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (c AzureMsiCredentials) Name() string {
}

func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) {
if !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) {
return nil, nil
}
env := cfg.Environment()
Expand Down
44 changes: 40 additions & 4 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)
Expand All @@ -22,13 +23,25 @@ var ErrCannotConfigureDefault = fmt.Errorf("cannot configure default credentials
// INTERNAL: This function is not part of the public API and is subject to
// change. Users are encouraged to use an explicit credentials strategy rather
// than relying on a custom credentials chain.
func NewCredentialsChain(strategies ...CredentialsStrategy) CredentialsStrategy {
func NewCredentialsChain(strategies ...CredentialsStrategy) *credentialsChain {
return &credentialsChain{strategies: strategies}
}

type credentialsChain struct {
strategies []CredentialsStrategy
name string
// cloudRequirements maps a strategy name to the cloud it requires. When
// set, the auto-detect loop skips strategies whose required cloud does not
// match the configured host. The map is not consulted when AuthType is
// explicitly set — in that case the named strategy is always attempted.
cloudRequirements map[string]environment.Cloud
name string
}

// WithCloudRequirements sets the cloud requirements for the chain and returns
// the chain for method chaining.
func (c *credentialsChain) WithCloudRequirements(m map[string]environment.Cloud) *credentialsChain {
c.cloudRequirements = m
return c
}

func (c *credentialsChain) Name() string {
Expand All @@ -42,7 +55,9 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti
}

// If an auth type is specified, try to configure the credentials for that
// specific auth type. If an error is encountered, return it.
// specific auth type. Cloud filtering is bypassed entirely so that users
// can explicitly request any strategy regardless of detected cloud (e.g.
// "azure-cli" on a GCP host).
if cfg.AuthType != "" {
for _, s := range c.strategies {
if s.Name() == cfg.AuthType {
Expand All @@ -58,6 +73,16 @@ func (c *credentialsChain) Configure(ctx context.Context, cfg *Config) (credenti
// succeeds, returns the credentials provider. If a strategy fails, swallow
// the error and try the next strategy.
for _, s := range c.strategies {
// In auto-detect mode, skip cloud-specific strategies that don't match
// the detected cloud. This prevents Azure strategies from being
// attempted silently on GCP hosts and vice-versa.
if requiredCloud, ok := c.cloudRequirements[s.Name()]; ok {
if cfg.Environment().Cloud != requiredCloud {
logger.Debugf(ctx, "Skipping %q: not configured for %s", s.Name(), requiredCloud)
continue
}
}

logger.Tracef(ctx, "Attempting to configure auth: %q", s.Name())
cp, err := s.Configure(ctx, cfg)
if err != nil || cp == nil {
Expand Down Expand Up @@ -119,7 +144,18 @@ func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (creden
// Google strategies.
GoogleCredentials{},
GoogleDefaultCredentials{},
)
).WithCloudRequirements(map[string]environment.Cloud{
// cloudRequirements declares the cloud each strategy requires.
// DefaultCredentials uses this to skip cloud-specific strategies in
// auto-detect mode when the host cloud does not match. Cloud filtering
// is bypassed when AuthType is explicitly set.
"github-oidc-azure": environment.CloudAzure,
"azure-msi": environment.CloudAzure,
"azure-client-secret": environment.CloudAzure,
"azure-cli": environment.CloudAzure,
"google-credentials": environment.CloudGCP,
"google-id": environment.CloudGCP,
})
cp, err := chain.Configure(ctx, cfg)
if err != nil {
return nil, err
Expand Down
57 changes: 57 additions & 0 deletions config/auth_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,65 @@ import (
"net/http"
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/config/credentials"
)

// recordingStrategy is a test helper that records whether Configure was called.
type recordingStrategy struct {
name string
called bool
}

func (r *recordingStrategy) Name() string { return r.name }
func (r *recordingStrategy) Configure(_ context.Context, _ *Config) (credentials.CredentialsProvider, error) {
r.called = true
return nil, nil
}

// TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch verifies that the
// chain skips a cloud-specific strategy in auto-detect mode when the detected
// cloud does not match the strategy's required cloud.
func TestCredentialsChain_CloudFiltering_SkipsOnCloudMismatch(t *testing.T) {
azureStrategy := &recordingStrategy{name: "azure-cli"}
chain := &credentialsChain{
strategies: []CredentialsStrategy{azureStrategy},
cloudRequirements: map[string]environment.Cloud{
"azure-cli": environment.CloudAzure,
},
}

// GCP host: azure-cli must be skipped in auto-detect mode.
cfg := &Config{Host: "https://xyz.gcp.databricks.com/", resolved: true}
chain.Configure(context.Background(), cfg) //nolint:errcheck

if azureStrategy.called {
t.Error("azure-cli strategy was called on GCP host, want it to be skipped in auto-detect mode")
}
}

// TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType verifies that
// the cloud filter is bypassed when AuthType is explicitly set, so that a user
// can request "azure-cli" even on a GCP host.
func TestCredentialsChain_CloudFiltering_BypassesOnExplicitAuthType(t *testing.T) {
azureStrategy := &recordingStrategy{name: "azure-cli"}
chain := &credentialsChain{
strategies: []CredentialsStrategy{azureStrategy},
cloudRequirements: map[string]environment.Cloud{
"azure-cli": environment.CloudAzure,
},
}

// GCP host but auth_type is explicitly set: cloud filter must be bypassed.
cfg := &Config{Host: "https://xyz.gcp.databricks.com/", AuthType: "azure-cli", resolved: true}
chain.Configure(context.Background(), cfg) //nolint:errcheck

if !azureStrategy.called {
t.Error("azure-cli strategy was not called despite explicit auth_type on GCP host, want bypass of cloud filter")
}
}

func TestDefaultCredentialStrategy(t *testing.T) {
original := DefaultCredentialStrategyProvider
t.Cleanup(func() { DefaultCredentialStrategyProvider = original })
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c GoogleCredentials) Name() string {
}

func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleCredentials == "" || !cfg.IsGcp() {
if cfg.GoogleCredentials == "" {
return nil, nil
}
json, err := readCredentials(cfg.GoogleCredentials)
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c GoogleDefaultCredentials) Name() string {
}

func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() {
if cfg.GoogleServiceAccount == "" {
return nil, nil
}
inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, c.opts...)
Expand Down
Loading