Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
### Documentation

### Internal Changes
* Move cloud-based credential filtering from individual strategies into `DefaultCredentials`. Azure strategies are skipped on GCP/AWS hosts in auto-detect mode; GCP strategies are skipped on Azure/AWS hosts. When `auth_type` is explicitly set (e.g. `azure-cli`), cloud filtering is bypassed so the named strategy is always attempted regardless of host cloud.

### API Changes
24 changes: 24 additions & 0 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from . import azure, oauth, oidc, oidc_token_supplier
from .client_types import ClientType
from .environments import Cloud

CredentialsProvider = Callable[[], Dict[str, str]]

Expand Down Expand Up @@ -1145,6 +1146,19 @@ def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
return model_serving_auth_visitor(cfg)


# _CLOUD_REQUIREMENTS maps auth type names to the cloud they require.
# DefaultCredentials uses this to skip cloud-specific strategies in auto-detect
# mode when the host cloud does not match. Cloud filtering is bypassed when
# auth_type is explicitly set.
_CLOUD_REQUIREMENTS: Dict[str, Cloud] = {
"github-oidc-azure": Cloud.AZURE,
"azure-client-secret": Cloud.AZURE,
"azure-cli": Cloud.AZURE,
"google-credentials": Cloud.GCP,
"google-id": Cloud.GCP,
}


class DefaultCredentials:
"""Select the first applicable credential provider from the chain"""

Expand Down Expand Up @@ -1189,6 +1203,16 @@ def __call__(self, cfg: "Config") -> CredentialsProvider:
# ignore other auth types if one is explicitly enforced
logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
continue
# In auto-detect mode, skip cloud-specific strategies that don't
# match the detected cloud. This prevents Azure strategies from
# being attempted on GCP hosts and vice-versa. When auth_type is
# explicitly set, cloud filtering is bypassed so the named strategy
# is always attempted regardless of detected host cloud.
if not cfg.auth_type and auth_type in _CLOUD_REQUIREMENTS:
required_cloud = _CLOUD_REQUIREMENTS[auth_type]
if cfg.environment.cloud != required_cloud:
logger.debug(f'Skipping "{auth_type}": not configured for {required_cloud.value}')
continue
logger.debug(f"Attempting to configure auth: {auth_type}")
try:
# The header factory might be None if the provider cannot be
Expand Down
55 changes: 55 additions & 0 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from databricks.sdk import credentials_provider, oauth, oidc
from databricks.sdk.client_types import ClientType
from databricks.sdk.config import Config
from databricks.sdk.environments import Cloud


# Tests for external_browser function
Expand Down Expand Up @@ -737,3 +738,57 @@ def test_azure_cli_returns_none_without_effective_azure_login_app_id(self):
# Should return None due to missing requirement
provider = credentials_provider.azure_cli(mock_cfg)
assert provider is None


class TestDefaultCredentialsCloudFiltering:
"""Tests for cloud-based credential filtering in DefaultCredentials."""

def _recording_strategy(self, name: str, cloud: Cloud = None):
"""Returns a mock CredentialsStrategy that records whether it was called."""
called = []

class _Strategy:
def auth_type(self):
return name

def __call__(self, cfg):
called.append(True)
return None

strategy = _Strategy()
strategy.called = called
return strategy

def test_skips_azure_strategy_on_gcp_host_in_auto_detect_mode(self):
"""In auto-detect mode, Azure strategies must be skipped on a GCP host."""
azure_strategy = self._recording_strategy("azure-cli")

dc = credentials_provider.DefaultCredentials()
dc._auth_providers = [azure_strategy]

cfg = Mock()
cfg.auth_type = None # auto-detect
cfg.environment = Mock()
cfg.environment.cloud = Cloud.GCP

with pytest.raises(ValueError):
dc(cfg)

assert not azure_strategy.called, "azure-cli must not be called on a GCP host in auto-detect mode"

def test_bypasses_cloud_filter_when_auth_type_explicitly_set(self):
"""When auth_type is explicitly set, cloud filtering must be bypassed."""
azure_strategy = self._recording_strategy("azure-cli")

dc = credentials_provider.DefaultCredentials()
dc._auth_providers = [azure_strategy]

cfg = Mock()
cfg.auth_type = "azure-cli" # explicitly set
cfg.environment = Mock()
cfg.environment.cloud = Cloud.GCP

with pytest.raises(ValueError):
dc(cfg)

assert azure_strategy.called, "azure-cli must be called despite GCP host when auth_type is explicitly set"
Loading