Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions common/environment/environments.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,48 @@
package environment

import (
"encoding/json"
"fmt"
"strings"
)

// Cloud identifies the cloud provider that hosts a Databricks deployment.
type Cloud string

// Cloud providers known to the SDK. CloudUnknown is the empty string, i.e. the
// zero value of Cloud, and stands for "no provider set".
const (
	CloudAWS     Cloud = "AWS"
	CloudAzure   Cloud = "Azure"
	CloudGCP     Cloud = "GCP"
	CloudUnknown Cloud = ""
)

// UnmarshalJSON implements json.Unmarshaler. The payload is decoded as a JSON
// string and the result of normalizeCloud on that string is stored in c. If the
// payload cannot be decoded as a string, the decoding error is returned and c
// is left untouched.
func (c *Cloud) UnmarshalJSON(data []byte) error {
	var name string
	err := json.Unmarshal(data, &name)
	if err != nil {
		return err
	}

	normalized := normalizeCloud(name)
	*c = normalized
	return nil
}

// normalizeCloud maps a cloud name onto its canonical Cloud constant. The
// comparison ignores letter case, so "aws", "Aws" and "AWS" all yield CloudAWS.
func normalizeCloud(cloud string) Cloud {
	upper := strings.ToUpper(cloud)
	if upper == "AWS" {
		return CloudAWS
	}
	if upper == "AZURE" {
		return CloudAzure
	}
	if upper == "GCP" {
		return CloudGCP
	}
	// Any other name is passed through verbatim, for forward compatibility with
	// cloud providers that are not yet supported by the SDK. The empty string
	// also ends up here; Cloud("") is CloudUnknown.
	return Cloud(cloud)
}

type DatabricksEnvironment struct {
Cloud Cloud
DnsZone string
Expand Down Expand Up @@ -51,7 +81,6 @@ func DefaultEnvironment() DatabricksEnvironment {
Cloud: CloudAWS,
DnsZone: ".cloud.databricks.com",
}

}

var envs = []DatabricksEnvironment{
Expand Down
26 changes: 24 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@ type Config struct {
// HTTPTransport can be overridden for unit testing, and used together with tooling like https://github.com/google/go-replayers
HTTPTransport http.RoundTripper

// Environment override to return when resolving the current environment.
// Cloud is the cloud provider for this Databricks deployment (AWS, Azure, GCP, or CloudUnknown).
//
// Experimental: subject to change.
Cloud environment.Cloud `name:"cloud" env:"DATABRICKS_CLOUD" auth:"-"`

DatabricksEnvironment *environment.DatabricksEnvironment

// When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier.
Expand Down Expand Up @@ -352,16 +356,25 @@ func (c *Config) IsAzure() bool {
if c.AzureResourceID != "" {
return true
}
if c.Cloud != environment.CloudUnknown {
return c.Cloud == environment.CloudAzure
}
return c.Environment().Cloud == environment.CloudAzure
}

// IsGcp reports whether the client is configured for Databricks on Google Cloud.
//
// An explicitly set Cloud takes precedence; only when it is CloudUnknown is the
// cloud of c.Environment() consulted.
func (c *Config) IsGcp() bool {
	cloud := c.Cloud
	if cloud == environment.CloudUnknown {
		cloud = c.Environment().Cloud
	}
	return cloud == environment.CloudGCP
}

// IsAws reports whether the client is configured for Databricks on AWS.
//
// An explicitly set Cloud is authoritative. Otherwise AWS is assumed for any
// non-empty Host that is neither Azure nor GCP.
func (c *Config) IsAws() bool {
	switch {
	case c.Cloud != environment.CloudUnknown:
		return c.Cloud == environment.CloudAWS
	case c.Host == "":
		// Without a host there is nothing to classify.
		return false
	default:
		return !(c.IsAzure() || c.IsGcp())
	}
}

Expand Down Expand Up @@ -643,7 +656,7 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS

// resolveHostMetadata populates missing config fields from the host's
// /.well-known/databricks-config discovery endpoint. It back-fills AccountID,
// WorkspaceID, Cloud, and DiscoveryURL (with any {account_id} placeholder substituted)
// if those fields are not already set. Returns an error if AccountID cannot be
// resolved or no oidc_endpoint is present in the metadata.
//
Expand All @@ -657,6 +670,7 @@ func (c *Config) resolveHostMetadata(ctx context.Context) error {
}
meta, err := getHostMetadata(ctx, c.CanonicalHostName(), c.refreshClient)
if err != nil {
logger.Debugf(ctx, "Failed to fetch host metadata: %v", err)
return err
}
if c.AccountID == "" && meta.AccountID != "" {
Expand All @@ -670,6 +684,14 @@ func (c *Config) resolveHostMetadata(ctx context.Context) error {
logger.Debugf(ctx, "Resolved workspace_id from host metadata: %q", meta.WorkspaceID)
c.WorkspaceID = meta.WorkspaceID
}
if c.Cloud == "" && meta.Cloud != environment.CloudUnknown {
logger.Debugf(ctx, "Resolved cloud from host metadata: %q", meta.Cloud)
c.Cloud = meta.Cloud
}
if c.Cloud == "" {
c.Cloud = c.Environment().Cloud
logger.Debugf(ctx, "Resolved cloud from hostname: %q", c.Cloud)
}
if c.DiscoveryURL == "" {
if meta.OIDCEndpoint == "" {
return fmt.Errorf("discovery_url is not configured and could not be resolved from host metadata")
Expand Down
155 changes: 155 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,158 @@ func TestConfig_ResolveHostMetadata_NoHost(t *testing.T) {
t.Fatal(err)
}
}

// TestConfig_ResolveHostMetadata_PopulatesCloudFromAPI checks that a "cloud"
// value returned by the well-known endpoint is copied into Config.Cloud when
// the config does not set one itself.
func TestConfig_ResolveHostMetadata_PopulatesCloudFromAPI(t *testing.T) {
	body := `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "Azure"}`
	transport := fixtures.SliceTransport{
		{
			Method:   "GET",
			Resource: "/.well-known/databricks-config",
			Status:   200,
			Response: body,
		},
	}
	c := &Config{
		Host:          testHMHost,
		Token:         "t",
		HTTPTransport: transport,
	}

	err := c.resolveHostMetadata(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if got := c.Cloud; got != "Azure" {
		t.Errorf("unexpected Cloud: %q", got)
	}
}

// TestConfig_ResolveHostMetadata_CloudFallbackToDNS checks that, when the
// metadata response carries no "cloud" field, Config.Cloud is derived from the
// host name instead (an Azure host here).
func TestConfig_ResolveHostMetadata_CloudFallbackToDNS(t *testing.T) {
	const host = "https://my-workspace.azuredatabricks.net"

	body := `{"oidc_endpoint": "` + host + `/oidc", "account_id": "` + testHMAccountID + `"}`
	transport := fixtures.SliceTransport{
		{
			Method:   "GET",
			Resource: "/.well-known/databricks-config",
			Status:   200,
			Response: body,
		},
	}
	c := &Config{
		Host:          host,
		Token:         "t",
		HTTPTransport: transport,
	}

	err := c.resolveHostMetadata(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if got := c.Cloud; got != "Azure" {
		t.Errorf("unexpected Cloud from DNS fallback: %q", got)
	}
}

// TestConfig_ResolveHostMetadata_DoesNotOverwriteExistingCloud checks that a
// Cloud already present on the config wins over the "cloud" value reported by
// the metadata endpoint.
func TestConfig_ResolveHostMetadata_DoesNotOverwriteExistingCloud(t *testing.T) {
	const want = "GCP"

	// The endpoint deliberately reports a different cloud than the config.
	body := `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "AWS"}`
	transport := fixtures.SliceTransport{
		{
			Method:   "GET",
			Resource: "/.well-known/databricks-config",
			Status:   200,
			Response: body,
		},
	}
	c := &Config{
		Host:          testHMHost,
		Token:         "t",
		Cloud:         want,
		HTTPTransport: transport,
	}

	err := c.resolveHostMetadata(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if c.Cloud != want {
		t.Errorf("Cloud was overwritten: got %q, want %q", c.Cloud, want)
	}
}

// TestConfig_ResolveHostMetadata_Clouds runs resolveHostMetadata against
// metadata responses carrying different "cloud" values and checks the Cloud
// that ends up on the config: known names in any letter case, a name the SDK
// does not know, and the empty string.
func TestConfig_ResolveHostMetadata_Clouds(t *testing.T) {
	// Each row is a distinct input; rows that merely repeated another row's
	// cloudJSON/wantCloud pair under a different name have been dropped.
	tests := []struct {
		name      string
		cloudJSON string
		wantCloud string
	}{
		{
			name:      "aws lowercase",
			cloudJSON: "aws",
			wantCloud: "AWS",
		},
		{
			name:      "AWS uppercase",
			cloudJSON: "AWS",
			wantCloud: "AWS",
		},
		{
			name:      "azure lowercase",
			cloudJSON: "azure",
			wantCloud: "Azure",
		},
		{
			name:      "AZURE uppercase",
			cloudJSON: "AZURE",
			wantCloud: "Azure",
		},
		{
			name:      "Azure title case",
			cloudJSON: "Azure",
			wantCloud: "Azure",
		},
		{
			name:      "gcp lowercase",
			cloudJSON: "gcp",
			wantCloud: "GCP",
		},
		{
			name:      "GCP uppercase",
			cloudJSON: "GCP",
			wantCloud: "GCP",
		},
		{
			name:      "Another cloud is supported",
			cloudJSON: "Another",
			wantCloud: "Another",
		},
		{
			name:      "Unknown cloud falls back to DNS",
			cloudJSON: "",
			wantCloud: "AWS", // Falls back to DNS-based detection for testHMHost
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &Config{
				Host:  testHMHost,
				Token: "t",
				HTTPTransport: fixtures.SliceTransport{
					{
						Method:   "GET",
						Resource: "/.well-known/databricks-config",
						Status:   200,
						Response: `{"oidc_endpoint": "` + testHMHost + `/oidc", "account_id": "` + testHMAccountID + `", "cloud": "` + tc.cloudJSON + `"}`,
					},
				},
			}
			if err := cfg.resolveHostMetadata(context.Background()); err != nil {
				t.Fatal(err)
			}
			if string(cfg.Cloud) != tc.wantCloud {
				t.Errorf("unexpected Cloud: got %q, want %q", cfg.Cloud, tc.wantCloud)
			}
		})
	}
}
3 changes: 3 additions & 0 deletions config/environments.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"github.com/databricks/databricks-sdk-go/common/environment"
)

// Deprecated: Use the Cloud field and cloud-specific helper methods (IsAws, IsAzure, IsGcp)
// instead. Environment() returns environment metadata including cloud type and Azure-specific
// endpoints.
func (c *Config) Environment() environment.DatabricksEnvironment {
// Use the provided environment if specified. Tests may configure the client with different hostnames,
// like localhost, which are not resolvable to a known environment, while needing to mock a specific environment.
Expand Down
56 changes: 56 additions & 0 deletions config/environments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,59 @@ func TestOverriddenEnvironmentIsReturned(t *testing.T) {
}
assert.Equal(t, "holla", c.Environment().DnsZone)
}

// TestCloudField_AWS checks that an explicit AWS Cloud is reported by IsAws
// and by neither IsAzure nor IsGcp.
func TestCloudField_AWS(t *testing.T) {
	cfg := &Config{
		Cloud: environment.CloudAWS,
		Host:  "https://my-workspace.cloud.databricks.com",
	}

	assert.True(t, cfg.IsAws())
	assert.False(t, cfg.IsAzure())
	assert.False(t, cfg.IsGcp())
}

// TestCloudField_Azure checks that an explicit Azure Cloud is reported by
// IsAzure and by neither IsAws nor IsGcp.
func TestCloudField_Azure(t *testing.T) {
	cfg := &Config{
		Cloud: environment.CloudAzure,
		Host:  "https://my-workspace.azuredatabricks.net",
	}

	assert.True(t, cfg.IsAzure())
	assert.False(t, cfg.IsAws())
	assert.False(t, cfg.IsGcp())
}

// TestCloudField_GCP checks that an explicit GCP Cloud is reported by IsGcp
// and by neither IsAws nor IsAzure.
func TestCloudField_GCP(t *testing.T) {
	cfg := &Config{
		Cloud: environment.CloudGCP,
		Host:  "https://my-workspace.gcp.databricks.com",
	}

	assert.True(t, cfg.IsGcp())
	assert.False(t, cfg.IsAws())
	assert.False(t, cfg.IsAzure())
}

// TestCloudField_FallsBackToEnvironment checks that, with Cloud left unset,
// the cloud helpers still classify an Azure host name as Azure.
func TestCloudField_FallsBackToEnvironment(t *testing.T) {
	cfg := &Config{Host: "https://my-workspace.azuredatabricks.net"}

	assert.True(t, cfg.IsAzure())
	assert.False(t, cfg.IsAws())
	assert.False(t, cfg.IsGcp())
}

// TestCloudField_PrefersCloudOverEnvironment checks that an explicit Cloud
// wins over what the host name suggests: the host is an AWS-style name, yet
// the config must report Azure only.
func TestCloudField_PrefersCloudOverEnvironment(t *testing.T) {
	c := &Config{
		Host:  "https://my-workspace.cloud.databricks.com",
		Cloud: environment.CloudAzure,
	}
	assert.True(t, c.IsAzure())
	assert.False(t, c.IsAws())
	// Same three-way check as the other TestCloudField_* tests.
	assert.False(t, c.IsGcp())
}

// TestCloudField_EmptyStringIsCloudUnknown pins the contract that an empty
// Cloud string and environment.CloudUnknown are the same value.
func TestCloudField_EmptyStringIsCloudUnknown(t *testing.T) {
	cfg := &Config{
		Cloud: "",
		Host:  "https://localhost:8080",
	}

	assert.Equal(t, environment.CloudUnknown, cfg.Cloud)
}
4 changes: 4 additions & 0 deletions config/host_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/httpclient"
)

Expand All @@ -18,6 +19,9 @@ type hostMetadata struct {

// WorkspaceID is the Databricks workspace ID associated with this host, if available.
WorkspaceID string `json:"workspace_id"`

// Cloud is the cloud provider for this Databricks deployment (AWS, Azure, or GCP).
Cloud environment.Cloud `json:"cloud"`
}

// getHostMetadata fetches the raw Databricks well-known configuration from
Expand Down
Loading
Loading