Skip to content
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
### Bugs Fixed

### Other Changes
* `ManagedIdentityCredential` no longer probes IMDS before requesting a token
from it. Also, an error response from IMDS no longer disables a credential
instance. Following an error, a credential instance will continue to send
requests to IMDS as necessary.

## 0.12.0 (2021-11-02)
### Breaking Changes
Expand Down
2 changes: 2 additions & 0 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package azidentity
import (
"context"
"errors"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -56,6 +57,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
msiCred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions})
if err == nil {
creds = append(creds, msiCred)
msiCred.client.imdsTimeout = time.Second
} else {
errMsg += err.Error()
}
Expand Down
43 changes: 0 additions & 43 deletions sdk/azidentity/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package azidentity

import (
"fmt"
"os"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand All @@ -19,30 +18,6 @@ import (
// used when obtaining credentials and the type of credential used.
const EventAuthentication log.Event = "Authentication"

// log environment variables that can be used for credential types
func logEnvVars() {
if !log.Should(EventAuthentication) {
return
}
// Log available environment variables
envVars := []string{}
if envCheck := os.Getenv("AZURE_TENANT_ID"); len(envCheck) > 0 {
envVars = append(envVars, "AZURE_TENANT_ID")
}
if envCheck := os.Getenv("AZURE_CLIENT_ID"); len(envCheck) > 0 {
envVars = append(envVars, "AZURE_CLIENT_ID")
}
if envCheck := os.Getenv("AZURE_CLIENT_SECRET"); len(envCheck) > 0 {
envVars = append(envVars, "AZURE_CLIENT_SECRET")
}
if envCheck := os.Getenv(azureAuthorityHost); len(envCheck) > 0 {
envVars = append(envVars, azureAuthorityHost)
}
if len(envVars) > 0 {
log.Writef(EventAuthentication, "Azure Identity => Found the following environment variables:\n\t%s", strings.Join(envVars, ", "))
}
}

func logGetTokenSuccess(cred azcore.TokenCredential, opts policy.TokenRequestOptions) {
if !log.Should(EventAuthentication) {
return
Expand All @@ -56,24 +31,6 @@ func logCredentialError(credName string, err error) {
log.Writef(EventAuthentication, "Azure Identity => ERROR in %s: %s", credName, err.Error())
}

func logMSIEnv(msi msiType) {
if !log.Should(EventAuthentication) {
return
}
var msg string
switch msi {
case msiTypeIMDS:
msg = "Azure Identity => Managed Identity environment: IMDS"
case msiTypeAppServiceV20170901, msiTypeCloudShell, msiTypeAppServiceV20190801:
msg = "Azure Identity => Managed Identity environment: MSI_ENDPOINT"
case msiTypeUnavailable:
msg = "Azure Identity => Managed Identity environment: Unavailable"
default:
msg = "Azure Identity => Managed Identity environment: Unknown"
}
log.Write(EventAuthentication, msg)
}

func addGetTokenFailureLogs(credName string, err error, includeStack bool) {
if !log.Should(EventAuthentication) {
return
Expand Down
170 changes: 68 additions & 102 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,16 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
)

const (
headerMetadata = "Metadata"
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)

const (
arcIMDSEndpoint = "IMDS_ENDPOINT"
identityEndpoint = "IDENTITY_ENDPOINT"
identityHeader = "IDENTITY_HEADER"
identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
headerMetadata = "Metadata"
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
msiEndpoint = "MSI_ENDPOINT"
msiSecret = "MSI_SECRET"
imdsAPIVersion = "2018-02-01"
Expand All @@ -45,26 +43,22 @@ const (
type msiType int

const (
msiTypeUnknown msiType = 0
msiTypeIMDS msiType = 1
msiTypeAppServiceV20170901 msiType = 2
msiTypeCloudShell msiType = 3
msiTypeUnavailable msiType = 4
msiTypeAppServiceV20190801 msiType = 5
msiTypeAzureArc msiType = 6
msiTypeServiceFabric msiType = 7
msiTypeAppServiceV20170901 msiType = iota
msiTypeAppServiceV20190801
msiTypeAzureArc
msiTypeCloudShell
msiTypeIMDS
msiTypeServiceFabric
)

// managedIdentityClient provides the base for authenticating in managed identity environments
// This type includes an runtime.Pipeline and TokenCredentialOptions.
type managedIdentityClient struct {
pipeline runtime.Pipeline
imdsAPIVersion string
imdsAvailableTimeout time.Duration
msiType msiType
endpoint string
id ManagedIDKind
unavailableMessage string
pipeline runtime.Pipeline
msiType msiType
endpoint string
id ManagedIDKind
imdsTimeout time.Duration
}

type wrappedNumber json.Number
Expand All @@ -77,8 +71,8 @@ func (n *wrappedNumber) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, (*json.Number)(n))
}

// setRetryOptionDefaults sets zero-valued fields to default values appropriate for IMDS
func setRetryOptionDefaults(o *policy.RetryOptions) {
// setIMDSRetryOptionDefaults sets zero-valued fields to default values appropriate for IMDS
func setIMDSRetryOptionDefaults(o *policy.RetryOptions) {
if o.MaxRetries == 0 {
o.MaxRetries = 5
}
Expand Down Expand Up @@ -111,36 +105,60 @@ func setRetryOptionDefaults(o *policy.RetryOptions) {
}
}

// newDefaultMSIPipeline creates a pipeline using the specified pipeline options needed
// for a Managed Identity, such as a MSI specific retry policy.
func newDefaultMSIPipeline(o ManagedIdentityCredentialOptions) runtime.Pipeline {
cp := o.ClientOptions
setRetryOptionDefaults(&cp.Retry)
return runtime.NewPipeline(component, version, runtime.PipelineOptions{}, &cp)
}

// newManagedIdentityClient creates a new instance of the ManagedIdentityClient with the ManagedIdentityCredentialOptions
// that are passed into it along with a default pipeline.
// options: ManagedIdentityCredentialOptions configure policies for the pipeline and the authority host that
// will be used to retrieve tokens and authenticate
func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *managedIdentityClient {
logEnvVars()
return &managedIdentityClient{
id: options.ID,
pipeline: newDefaultMSIPipeline(*options), // a pipeline that includes the specific requirements for MSI authentication, such as custom retry policy options
imdsAPIVersion: imdsAPIVersion, // this field will be set to whatever value exists in the constant and is used when creating requests to IMDS
imdsAvailableTimeout: 500 * time.Millisecond, // we allow a timeout of 500 ms since the endpoint might be slow to respond
msiType: msiTypeUnknown, // when creating a new managedIdentityClient, the current MSI type is unknown and will be tested for and replaced once authenticate() is called from GetToken on the credential side
func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*managedIdentityClient, error) {
if options == nil {
options = &ManagedIdentityCredentialOptions{}
}
cp := options.ClientOptions
c := managedIdentityClient{id: options.ID, endpoint: imdsEndpoint, msiType: msiTypeIMDS}
env := "IMDS"
if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
if _, ok := os.LookupEnv(msiSecret); ok {
env = "App Service"
c.endpoint = endpoint
c.msiType = msiTypeAppServiceV20170901
} else {
env = "Cloud Shell"
c.endpoint = endpoint
c.msiType = msiTypeCloudShell
}
} else if endpoint, ok := os.LookupEnv(identityEndpoint); ok {
if _, ok := os.LookupEnv(identityHeader); ok {
if _, ok := os.LookupEnv(identityServerThumbprint); ok {
env = "Service Fabric"
c.endpoint = endpoint
c.msiType = msiTypeServiceFabric
}
} else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
env = "Azure Arc"
c.endpoint = endpoint
c.msiType = msiTypeAzureArc
}
} else {
setIMDSRetryOptionDefaults(&cp.Retry)
}
c.pipeline = runtime.NewPipeline(component, version, runtime.PipelineOptions{}, &cp)

if log.Should(EventAuthentication) {
log.Writef(EventAuthentication, "Azure Identity => Managed Identity Credential will use %s managed identity", env)
}

return &c, nil
}

// authenticate creates an authentication request for a Managed Identity and returns the resulting Access Token if successful.
// ctx: The current context for controlling the request lifetime.
// clientID: The client (application) ID of the service principal.
// scopes: The scopes required for the token.
func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (*azcore.AccessToken, error) {
if len(c.unavailableMessage) > 0 {
return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage)
var cancel context.CancelFunc
if c.imdsTimeout > 0 && c.msiType == msiTypeIMDS {
ctx, cancel = context.WithTimeout(ctx, c.imdsTimeout)
defer cancel()
}

msg, err := c.createAuthRequest(ctx, id, scopes)
Expand All @@ -150,9 +168,15 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi

resp, err := c.pipeline.Do(msg)
if err != nil {
return nil, err
if cancel != nil && errors.Is(err, context.DeadlineExceeded) {
return nil, newCredentialUnavailableError("Managed Identity Credential", "IMDS token request timed out")
}
return nil, newAuthenticationFailedError(err, nil)
}

// got a response, remove the IMDS timeout so future requests use the transport's configuration
c.imdsTimeout = 0

if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
return c.createAccessToken(resp)
}
Expand All @@ -161,8 +185,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
if id != nil {
return nil, newAuthenticationFailedError(errors.New("the requested identity isn't assigned to this resource"), resp)
}
c.unavailableMessage = "No default identity is assigned to this resource."
return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage)
return nil, newCredentialUnavailableError("Managed Identity Credential", "no default identity is assigned to this resource")
}

return nil, newAuthenticationFailedError(errors.New("authentication failed"), resp)
Expand Down Expand Up @@ -229,15 +252,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
case msiTypeCloudShell:
return c.createCloudShellAuthRequest(ctx, id, scopes)
default:
errorMsg := ""
switch c.msiType {
case msiTypeUnavailable:
errorMsg = "unavailable"
default:
errorMsg = "unknown"
}
c.unavailableMessage = "managed identity support is " + errorMsg
return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage)
return nil, newCredentialUnavailableError("Managed Identity Credential", "managed identity isn't supported in this environment")
}
}

Expand All @@ -248,7 +263,7 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id Ma
}
request.Raw().Header.Set(headerMetadata, "true")
q := request.Raw().URL.Query()
q.Add("api-version", c.imdsAPIVersion)
q.Add("api-version", imdsAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if id != nil {
if id.idKind() == miResourceID {
Expand Down Expand Up @@ -383,52 +398,3 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context,
}
return request, nil
}

func (c *managedIdentityClient) getMSIType() (msiType, error) {
if c.msiType == msiTypeUnknown { // if we haven't already determined the msiType
if endpointEnvVar := os.Getenv(msiEndpoint); endpointEnvVar != "" { // if the env var MSI_ENDPOINT is set
c.endpoint = endpointEnvVar
if secretEnvVar := os.Getenv(msiSecret); secretEnvVar != "" { // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
c.msiType = msiTypeAppServiceV20170901
} else { // if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
c.msiType = msiTypeCloudShell
}
} else if endpointEnvVar := os.Getenv(identityEndpoint); endpointEnvVar != "" { // check for IDENTITY_ENDPOINT
c.endpoint = endpointEnvVar
if header := os.Getenv(identityHeader); header != "" { // if BOTH the env vars IDENTITY_ENDPOINT and IDENTITY_HEADER are set the msiType is AppService
c.msiType = msiTypeAppServiceV20190801
if thumbprint := os.Getenv(identityServerThumbprint); thumbprint != "" { // if IDENTITY_SERVER_THUMBPRINT is set the environment is Service Fabric
c.msiType = msiTypeServiceFabric
}
} else if arcIMDS := os.Getenv(arcIMDSEndpoint); arcIMDS != "" {
c.msiType = msiTypeAzureArc
} else {
c.msiType = msiTypeUnavailable
return c.msiType, newCredentialUnavailableError("Managed Identity Credential", "this environment is not supported")
}
} else if c.imdsAvailable() { // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds
c.endpoint = imdsEndpoint
c.msiType = msiTypeIMDS
} else { // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available
c.msiType = msiTypeUnavailable
return c.msiType, newCredentialUnavailableError("Managed Identity Credential", "no managed identity endpoint is available")
}
}
return c.msiType, nil
}

// performs an I/O request that has a timeout of 500 milliseconds
func (c *managedIdentityClient) imdsAvailable() bool {
tempCtx, cancel := context.WithTimeout(context.Background(), c.imdsAvailableTimeout)
defer cancel()
// this should never fail
request, _ := runtime.NewRequest(tempCtx, http.MethodGet, imdsEndpoint)
q := request.Raw().URL.Query()
q.Add("api-version", c.imdsAPIVersion)
request.Raw().URL.RawQuery = q.Encode()
resp, err := c.pipeline.Do(request)
if err == nil {
runtime.Drain(resp)
}
return err == nil
}
Loading