diff --git a/common/environment/environments.go b/common/environment/environments.go index 57acd99d7..9ba9b7a35 100644 --- a/common/environment/environments.go +++ b/common/environment/environments.go @@ -1,6 +1,7 @@ package environment import ( + "encoding/json" "fmt" "strings" ) @@ -8,11 +9,40 @@ import ( type Cloud string const ( - CloudAWS Cloud = "AWS" - CloudAzure Cloud = "Azure" - CloudGCP Cloud = "GCP" + CloudAWS Cloud = "AWS" + CloudAzure Cloud = "Azure" + CloudGCP Cloud = "GCP" + CloudUnknown Cloud = "" ) +// UnmarshalJSON automatically normalizes the cloud string when parsing JSON +func (c *Cloud) UnmarshalJSON(data []byte) error { + var rawString string + + if err := json.Unmarshal(data, &rawString); err != nil { + return err + } + + *c = normalizeCloud(rawString) + return nil +} + +func normalizeCloud(cloud string) Cloud { + switch strings.ToUpper(cloud) { + case "AWS": + return CloudAWS + case "AZURE": + return CloudAzure + case "GCP": + return CloudGCP + case "": + return CloudUnknown + // For forward compatibility with new cloud providers that are not yet supported by the SDK. + default: + return Cloud(cloud) + } +} + type DatabricksEnvironment struct { Cloud Cloud DnsZone string @@ -51,7 +81,6 @@ func DefaultEnvironment() DatabricksEnvironment { Cloud: CloudAWS, DnsZone: ".cloud.databricks.com", } - } var envs = []DatabricksEnvironment{ diff --git a/config/config.go b/config/config.go index e8be64535..9171d6cd7 100644 --- a/config/config.go +++ b/config/config.go @@ -210,7 +210,11 @@ type Config struct { // HTTPTransport can be overriden for unit testing and together with tooling like https://github.com/google/go-replayers HTTPTransport http.RoundTripper - // Environment override to return when resolving the current environment. + // Cloud is the cloud provider for this Databricks deployment (AWS, Azure, GCP, or CloudUnknown). + // + // Experimental: subject to change. + Cloud environment.Cloud `name:"cloud" env:"DATABRICKS_CLOUD" auth:"-"` + DatabricksEnvironment *environment.DatabricksEnvironment // When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. @@ -352,16 +356,25 @@ func (c *Config) IsAzure() bool { if c.AzureResourceID != "" { return true } + if c.Cloud != environment.CloudUnknown { + return c.Cloud == environment.CloudAzure + } return c.Environment().Cloud == environment.CloudAzure } // IsGcp returns if the client is configured for Databricks on Google Cloud. func (c *Config) IsGcp() bool { + if c.Cloud != environment.CloudUnknown { + return c.Cloud == environment.CloudGCP + } return c.Environment().Cloud == environment.CloudGCP } // IsAws returns if the client is configured for Databricks on AWS. func (c *Config) IsAws() bool { + if c.Cloud != environment.CloudUnknown { + return c.Cloud == environment.CloudAWS + } return c.Host != "" && !c.IsAzure() && !c.IsGcp() } @@ -643,7 +656,7 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS // resolveHostMetadata populates missing config fields from the host's // /.well-known/databricks-config discovery endpoint. It back-fills AccountID, -// WorkspaceID, and DiscoveryURL (with any {account_id} placeholder substituted) +// WorkspaceID, Cloud, and DiscoveryURL (with any {account_id} placeholder substituted) // if those fields are not already set. Returns an error if AccountID cannot be // resolved or no oidc_endpoint is present in the metadata. // @@ -657,6 +670,7 @@ func (c *Config) resolveHostMetadata(ctx context.Context) error { } meta, err := getHostMetadata(ctx, c.CanonicalHostName(), c.refreshClient) if err != nil { + logger.Debugf(ctx, "Failed to fetch host metadata: %v", err) return err } if c.AccountID == "" && meta.AccountID != "" { @@ -670,6 +684,14 @@ func (c *Config) resolveHostMetadata(ctx context.Context) error { logger.Debugf(ctx, "Resolved workspace_id from host metadata: %q", meta.WorkspaceID) c.WorkspaceID = meta.WorkspaceID } + if c.Cloud == "" && meta.Cloud != environment.CloudUnknown { + logger.Debugf(ctx, "Resolved cloud from host metadata: %q", meta.Cloud) + c.Cloud = meta.Cloud + } + if c.Cloud == "" { + c.Cloud = c.Environment().Cloud + logger.Debugf(ctx, "Resolved cloud from hostname: %q", c.Cloud) + } if c.DiscoveryURL == "" { if meta.OIDCEndpoint == "" { return fmt.Errorf("discovery_url is not configured and could not be resolved from host metadata") diff --git a/config/config_test.go b/config/config_test.go index 8b2f4ccac..ce0e5ada0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -666,3 +666,158 @@ func TestConfig_ResolveHostMetadata_NoHost(t *testing.T) { t.Fatal(err) } } + +func TestConfig_ResolveHostMetadata_PopulatesCloudFromAPI(t *testing.T) { + cfg := &Config{ + Host: testHMHost, + Token: "t", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/.well-known/databricks-config", + Status: 200, + Response: `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "Azure"}`, + }, + }, + } + if err := cfg.resolveHostMetadata(context.Background()); err != nil { + t.Fatal(err) + } + if cfg.Cloud != "Azure" { + t.Errorf("unexpected Cloud: %q", cfg.Cloud) + } +} + +func TestConfig_ResolveHostMetadata_CloudFallbackToDNS(t *testing.T) { + cfg := &Config{ + Host: "https://my-workspace.azuredatabricks.net", + Token: "t", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/.well-known/databricks-config", + Status: 200, + Response: `{"oidc_endpoint": "https://my-workspace.azuredatabricks.net/oidc", "account_id": "` + testHMAccountID + `"}`, + }, + }, + } + if err := cfg.resolveHostMetadata(context.Background()); err != nil { + t.Fatal(err) + } + if cfg.Cloud != "Azure" { + t.Errorf("unexpected Cloud from DNS fallback: %q", cfg.Cloud) + } +} + +func TestConfig_ResolveHostMetadata_DoesNotOverwriteExistingCloud(t *testing.T) { + cfg := &Config{ + Host: testHMHost, + Token: "t", + Cloud: "GCP", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/.well-known/databricks-config", + Status: 200, + Response: `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "AWS"}`, + }, + }, + } + if err := cfg.resolveHostMetadata(context.Background()); err != nil { + t.Fatal(err) + } + if cfg.Cloud != "GCP" { + t.Errorf("Cloud was overwritten: got %q, want %q", cfg.Cloud, "GCP") + } +} + +func TestConfig_ResolveHostMetadata_Clouds(t *testing.T) { + tests := []struct { + name string + cloudJSON string + wantCloud string + }{ + { + name: "AWS", + cloudJSON: "AWS", + wantCloud: "AWS", + }, + { + name: "Azure", + cloudJSON: "Azure", + wantCloud: "Azure", + }, + { + name: "GCP", + cloudJSON: "GCP", + wantCloud: "GCP", + }, + { + name: "aws lowercase", + cloudJSON: "aws", + wantCloud: "AWS", + }, + { + name: "AWS uppercase", + cloudJSON: "AWS", + wantCloud: "AWS", + }, + { + name: "azure lowercase", + cloudJSON: "azure", + wantCloud: "Azure", + }, + { + name: "AZURE uppercase", + cloudJSON: "AZURE", + wantCloud: "Azure", + }, + { + name: "Azure title case", + cloudJSON: "Azure", + wantCloud: "Azure", + }, + { + name: "gcp lowercase", + cloudJSON: "gcp", + wantCloud: "GCP", + }, + { + name: "GCP uppercase", + cloudJSON: "GCP", + wantCloud: "GCP", + }, + { + name: "Another cloud is supported", + cloudJSON: "Another", + wantCloud: "Another", + }, + { + name: "Unknown cloud falls back to DNS", + cloudJSON: "", + wantCloud: "AWS", // Falls back to DNS-based detection for testHMHost + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{ + Host: testHMHost, + Token: "t", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/.well-known/databricks-config", + Status: 200, + Response: `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "` + tc.cloudJSON + `"}`, + }, + }, + } + if err := cfg.resolveHostMetadata(context.Background()); err != nil { + t.Fatal(err) + } + if string(cfg.Cloud) != tc.wantCloud { + t.Errorf("unexpected Cloud: got %q, want %q", cfg.Cloud, tc.wantCloud) + } + }) + } +} diff --git a/config/environments.go b/config/environments.go index 597fdf7b6..5dfa76312 100644 --- a/config/environments.go +++ b/config/environments.go @@ -6,6 +6,9 @@ import ( "github.com/databricks/databricks-sdk-go/common/environment" ) +// Deprecated: Use the Cloud field and cloud-specific helper methods (IsAws, IsAzure, IsGcp) +// instead. Environment() returns environment metadata including cloud type and Azure-specific +// endpoints. func (c *Config) Environment() environment.DatabricksEnvironment { // Use the provided environment if specified. Tests may configure the client with different hostnames, // like localhost, which are not resolvable to a known environment, while needing to mock a specific environment. diff --git a/config/environments_test.go b/config/environments_test.go index 901b882b7..a134c393a 100644 --- a/config/environments_test.go +++ b/config/environments_test.go @@ -22,3 +22,59 @@ func TestOverriddenEnvironmentIsReturned(t *testing.T) { } assert.Equal(t, "holla", c.Environment().DnsZone) } + +func TestCloudField_AWS(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.cloud.databricks.com", + Cloud: environment.CloudAWS, + } + assert.True(t, c.IsAws()) + assert.False(t, c.IsAzure()) + assert.False(t, c.IsGcp()) +} + +func TestCloudField_Azure(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.azuredatabricks.net", + Cloud: environment.CloudAzure, + } + assert.True(t, c.IsAzure()) + assert.False(t, c.IsAws()) + assert.False(t, c.IsGcp()) +} + +func TestCloudField_GCP(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.gcp.databricks.com", + Cloud: environment.CloudGCP, + } + assert.True(t, c.IsGcp()) + assert.False(t, c.IsAws()) + assert.False(t, c.IsAzure()) +} + +func TestCloudField_FallsBackToEnvironment(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.azuredatabricks.net", + } + assert.True(t, c.IsAzure()) + assert.False(t, c.IsAws()) + assert.False(t, c.IsGcp()) +} + +func TestCloudField_PrefersCloudOverEnvironment(t *testing.T) { + c := &Config{ + Host: "https://my-workspace.cloud.databricks.com", + Cloud: environment.CloudAzure, + } + assert.True(t, c.IsAzure()) + assert.False(t, c.IsAws()) +} + +func TestCloudField_EmptyStringIsCloudUnknown(t *testing.T) { + c := &Config{ + Host: "https://localhost:8080", + Cloud: "", + } + assert.Equal(t, environment.CloudUnknown, c.Cloud) +} diff --git a/config/host_metadata.go b/config/host_metadata.go index 5c3e6e178..246e9979d 100644 --- a/config/host_metadata.go +++ b/config/host_metadata.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/httpclient" ) @@ -18,6 +19,9 @@ type hostMetadata struct { // WorkspaceID is the Databricks workspace ID associated with this host, if available. WorkspaceID string `json:"workspace_id"` + + // Cloud is the cloud provider for this Databricks deployment (AWS, Azure, or GCP). + Cloud environment.Cloud `json:"cloud"` } // getHostMetadata fetches the raw Databricks well-known configuration from diff --git a/config/host_metadata_test.go b/config/host_metadata_test.go index 9a92ecea4..1285a354c 100644 --- a/config/host_metadata_test.go +++ b/config/host_metadata_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/go-cmp/cmp" + "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" ) @@ -30,6 +31,7 @@ func TestGetHostMetadata_WorkspaceStaticOIDCEndpoint(t *testing.T) { "oidc_endpoint": testHMHost + "/oidc", "account_id": testHMAccountID, "workspace_id": testHMWorkspaceID, + "cloud": "AWS", }, }, }) @@ -41,6 +43,7 @@ func TestGetHostMetadata_WorkspaceStaticOIDCEndpoint(t *testing.T) { OIDCEndpoint: testHMHost + "/oidc", AccountID: testHMAccountID, WorkspaceID: testHMWorkspaceID, + Cloud: "AWS", } if diff := cmp.Diff(want, meta); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) @@ -82,3 +85,56 @@ func TestGetHostMetadata_HTTPError(t *testing.T) { t.Errorf("expected error containing %q, got %q", "fetching host metadata from", err.Error()) } } + +func TestGetHostMetadata_WithCloudField(t *testing.T) { + tests := []struct { + name string + cloud string + wantCloud environment.Cloud + }{ + { + name: "AWS", + cloud: "AWS", + wantCloud: environment.CloudAWS, + }, + { + name: "Azure", + cloud: "Azure", + wantCloud: environment.CloudAzure, + }, + { + name: "GCP", + cloud: "GCP", + wantCloud: environment.CloudGCP, + }, + { + name: "missing cloud field", + cloud: "", + wantCloud: environment.CloudUnknown, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + response := map[string]string{ + "oidc_endpoint": testHMHost + "/oidc", + "account_id": testHMAccountID, + } + if tc.cloud != "" { + response["cloud"] = tc.cloud + } + client := newTestAPIClient(fixtures.MappingTransport{ + "GET /.well-known/databricks-config": { + Status: 200, + Response: response, + }, + }) + meta, err := getHostMetadata(context.Background(), testHMHost, client) + if err != nil { + t.Fatal(err) + } + if meta.Cloud != tc.wantCloud { + t.Errorf("Cloud field mismatch: got %q, want %q", meta.Cloud, tc.wantCloud) + } + }) + } +}