diff --git a/wrappers/azurekeyvault/azurekeyvault.go b/wrappers/azurekeyvault/azurekeyvault.go index 68cf7f8a..ace6494e 100644 --- a/wrappers/azurekeyvault/azurekeyvault.go +++ b/wrappers/azurekeyvault/azurekeyvault.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "strings" + "sync" "sync/atomic" "time" @@ -34,6 +35,7 @@ const ( EnvAzureKeyVaultWrapperKeyName = "AZUREKEYVAULT_WRAPPER_KEY_NAME" EnvVaultAzureKeyVaultKeyName = "VAULT_AZUREKEYVAULT_KEY_NAME" + EnvAzureClientId = "AZURE_CLIENT_ID" ) // Wrapper is an Wrapper that uses Azure Key Vault @@ -94,8 +96,8 @@ func (v *Wrapper) SetConfig(ctx context.Context, opt ...wrapping.Option) (*wrapp } switch { - case os.Getenv("AZURE_CLIENT_ID") != "" && !opts.withDisallowEnvVars: - v.clientID = os.Getenv("AZURE_CLIENT_ID") + case os.Getenv(EnvAzureClientId) != "" && !opts.withDisallowEnvVars: + v.clientID = os.Getenv(EnvAzureClientId) case opts.withClientId != "": v.clientID = opts.withClientId } @@ -287,6 +289,8 @@ func (v *Wrapper) buildBaseURL() string { return fmt.Sprintf("https://%s.%s/", v.vaultName, v.environment.KeyVaultDNSSuffix) } +var managedClientIdLock sync.Mutex + func (v *Wrapper) getKeyVaultClient(withCertPool *x509.CertPool) (*azkeys.Client, error) { var err error var cred azcore.TokenCredential @@ -299,14 +303,28 @@ func (v *Wrapper) getKeyVaultClient(withCertPool *x509.CertPool) (*azkeys.Client return nil, fmt.Errorf("failed to get client secret credentials %w", err) } case v.clientID != "": - // Try a managed service credential with a specified client id - clientID := azidentity.ClientID(v.clientID) - cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ - ID: clientID, - }) + // Some hoops to jump through to make sure two wrappers being setup at the same time don't step on the + // env var + managedClientIdLock.Lock() + oldVal, found := os.LookupEnv(EnvAzureClientId) + unlock := func() { + if found { + os.Setenv(EnvAzureClientId, oldVal) + } else { + os.Unsetenv(EnvAzureClientId) + } + managedClientIdLock.Unlock() + } + + // This could be a managed identity auth, so supply the default credential provider with clientId and let it + // figure it out. Sort of a hack, but Azure's library doesn't allow us to specify clientID as an option. + os.Setenv(EnvAzureClientId, v.clientID) + cred, err = azidentity.NewDefaultAzureCredential(nil) if err != nil { - return nil, fmt.Errorf("failed to get managed identity credentials: %w", err) + unlock() + return nil, fmt.Errorf("failed to acquire managed identity credentials %w", err) } + unlock() // By default let Azure select existing credentials default: cred, err = azidentity.NewDefaultAzureCredential(nil)