Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions provider/pkg/provider/auth_azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func NewAzCoreIdentity(ctx context.Context, authConf *authConfiguration, baseCli
// Create the azcore.TokenCredential implementation based on the auth configuration.
// This routine evaluates the auth configuration and other environment variables,
// and ultimately resolves the Azure cloud and subscription ID.
cred, err := newSingleMethodAuthCredential(authConf, baseClientOpts)
cred, err := newSingleMethodAuthCredential(ctx, authConf, baseClientOpts)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func NewAzCoreIdentity(ctx context.Context, authConf *authConfiguration, baseCli
// - When a method is configured but instantiating the credential fails, we return an error and do not fall through to
// the next method.
// - Auxiliary or additional tenants are supported for SP with client secret and CLI authentication, not for others.
func newSingleMethodAuthCredential(authConf *authConfiguration, baseClientOpts azcore.ClientOptions) (azcore.TokenCredential, error) {
func newSingleMethodAuthCredential(ctx context.Context, authConf *authConfiguration, baseClientOpts azcore.ClientOptions) (azcore.TokenCredential, error) {
if authConf.useDefault {
logging.V(9).Infof("[auth] Using default Azure credential")
fmtErrorMessage := "A %s must be configured when authenticating using the Default Azure Credential."
Expand Down Expand Up @@ -215,7 +215,18 @@ func newSingleMethodAuthCredential(authConf *authConfiguration, baseClientOpts a
}
// note that the subscription ID is discoverable when using the Azure CLI credential and hence optional.
if authConf.subscriptionId != "" {
options.Subscription = authConf.subscriptionId
// Query the subscription to check if it's the default.
// This avoids triggering a shell quoting bug in the Azure SDK when using the default subscription.
activeSubscription, err := authConf.showSubscription(ctx, authConf.subscriptionId)
if err != nil {
return nil, err
}

// Only pass subscription to SDK if it's not the default subscription.
// When using the default, the SDK will auto-detect it without needing the parameter.
if !activeSubscription.IsDefault {
options.Subscription = authConf.subscriptionId
}
}
return azidentity.NewAzureCLICredential(options)
}
Expand Down
62 changes: 44 additions & 18 deletions provider/pkg/provider/auth_azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientSecretCredential{}, cred)
clientVal := reflect.ValueOf(cred).Elem().FieldByName("client")
Expand All @@ -171,7 +171,7 @@ func TestNewCredential(t *testing.T) {
clientSecret: "client-secret",
tenantId: "tenant-id",
}
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), "Subscription")
})
Expand All @@ -182,7 +182,7 @@ func TestNewCredential(t *testing.T) {
clientSecret: "client-secret",
subscriptionId: "subscription-id",
}
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), "Tenant")
})
Expand All @@ -198,7 +198,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientCertificateCredential{}, cred)
clientVal := reflect.ValueOf(cred).Elem().FieldByName("client")
Expand All @@ -216,7 +216,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to parse certificate")
})
Expand All @@ -232,7 +232,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to parse certificate")
require.Contains(t, err.Error(), "password incorrect")
Expand All @@ -246,7 +246,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientAssertionCredential{}, cred)
clientVal := reflect.ValueOf(cred).Elem().FieldByName("client")
Expand All @@ -265,7 +265,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientAssertionCredential{}, cred)
})
Expand All @@ -278,7 +278,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
require.ErrorIs(t, err, os.ErrNotExist)
})
Expand All @@ -292,7 +292,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientAssertionCredential{}, cred)
})
Expand All @@ -307,7 +307,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ClientAssertionCredential{}, cred)
})
Expand All @@ -334,7 +334,7 @@ func TestNewCredential(t *testing.T) {
tenantId: "tenant-id",
},
} {
_, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
_, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.Error(t, err)
}
})
Expand All @@ -344,7 +344,7 @@ func TestNewCredential(t *testing.T) {
useMsi: true,
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ManagedIdentityCredential{}, cred)
})
Expand All @@ -356,15 +356,15 @@ func TestNewCredential(t *testing.T) {
clientId: "123",
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.ManagedIdentityCredential{}, cred)
// FUTURE: assert cred.client.id = "123"
})

t.Run("CLI", func(t *testing.T) {
conf := &authConfiguration{}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.AzureCLICredential{}, cred)
})
Expand All @@ -373,7 +373,7 @@ func TestNewCredential(t *testing.T) {
conf := &authConfiguration{
auxTenants: []string{"123", "456"},
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.AzureCLICredential{}, cred)
optsVal := reflect.ValueOf(cred).Elem().FieldByName("opts")
Expand All @@ -384,20 +384,46 @@ func TestNewCredential(t *testing.T) {
t.Run("CLI with subscription id", func(t *testing.T) {
conf := &authConfiguration{
subscriptionId: "subscription-id",
// Mock showSubscription to return a non-default subscription
showSubscription: func(ctx context.Context, subscriptionID string) (*Subscription, error) {
return &Subscription{
ID: "subscription-id",
IsDefault: false, // Non-default subscription should be passed to SDK
}, nil
},
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.AzureCLICredential{}, cred)
optsVal := reflect.ValueOf(cred).Elem().FieldByName("opts")
require.Equal(t, "subscription-id", optsVal.FieldByName("Subscription").String())
})

t.Run("CLI with default subscription id", func(t *testing.T) {
conf := &authConfiguration{
subscriptionId: "subscription-id",
// Mock showSubscription to return the default subscription
showSubscription: func(ctx context.Context, subscriptionID string) (*Subscription, error) {
return &Subscription{
ID: "subscription-id",
IsDefault: true, // Default subscription should NOT be passed to SDK
}, nil
},
}
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.AzureCLICredential{}, cred)
optsVal := reflect.ValueOf(cred).Elem().FieldByName("opts")
// When it's the default subscription, we should NOT pass it to the SDK
require.Equal(t, "", optsVal.FieldByName("Subscription").String())
})

t.Run("Azure Default Credential", func(t *testing.T) {
conf := &authConfiguration{
useDefault: true,
subscriptionId: "subscription-id",
}
cred, err := newSingleMethodAuthCredential(conf, azcore.ClientOptions{})
cred, err := newSingleMethodAuthCredential(context.Background(), conf, azcore.ClientOptions{})
require.NoError(t, err)
require.IsType(t, &azidentity.DefaultAzureCredential{}, cred)
})
Expand Down
22 changes: 5 additions & 17 deletions provider/pkg/provider/azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"encoding/json"
"os"
"os/exec"
"runtime"
"strings"
"time"

Expand Down Expand Up @@ -52,25 +51,14 @@ type azSubscriptionProvider func(ctx context.Context, subscriptionID string) (*S
// this code is derived from "CLI token provider" code in the Azure SDK for Go:
// https://github.com/Azure/azure-sdk-for-go/blob/519e8ab1a0e433b755c31ebaa6b177dfc83cb838/sdk/azidentity/azure_cli_credential.go#L117-L172
var defaultAzSubscriptionProvider = func(ctx context.Context, subscriptionID string) (*Subscription, error) {
commandLine := "az account show -o json "
// Build command arguments as array to avoid shell quoting issues
args := []string{"account", "show", "-o", "json"}
if subscriptionID != "" {
// subscription needs quotes because it may contain spaces
commandLine += ` --subscription "` + subscriptionID + `"`
args = append(args, "--subscription", subscriptionID)
}
logging.V(9).Infof("Running command: %s", commandLine)
logging.V(9).Infof("Running command: az %s", strings.Join(args, " "))

var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
dir := os.Getenv("SYSTEMROOT")
if dir == "" {
return nil, newSubscriptionUnavailableError("environment variable 'SYSTEMROOT' has no value")
}
cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine)
cliCmd.Dir = dir
} else {
cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine)
cliCmd.Dir = "/bin"
}
cliCmd := exec.CommandContext(ctx, "az", args...)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh much cleaner!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As described in overview, I cribbed the original code from Azure SDK. Why, I wonder, did they use a shell?

cliCmd.Env = os.Environ()
var stdout, stderr bytes.Buffer
cliCmd.Stderr = &stderr
Expand Down
Loading