diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 933ab1740ab6..9947dd48d849 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -24,13 +24,15 @@ const ( ) const ( - arcIMDSEndpoint = "IMDS_ENDPOINT" - identityEndpoint = "IDENTITY_ENDPOINT" - identityHeader = "IDENTITY_HEADER" - msiEndpoint = "MSI_ENDPOINT" - msiSecret = "MSI_SECRET" - imdsAPIVersion = "2018-02-01" - azureArcAPIVersion = "2019-08-15" + arcIMDSEndpoint = "IMDS_ENDPOINT" + identityEndpoint = "IDENTITY_ENDPOINT" + identityHeader = "IDENTITY_HEADER" + identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT" + msiEndpoint = "MSI_ENDPOINT" + msiSecret = "MSI_SECRET" + imdsAPIVersion = "2018-02-01" + azureArcAPIVersion = "2019-08-15" + serviceFabricAPIVersion = "2019-07-01-preview" ) type msiType int @@ -43,6 +45,7 @@ const ( msiTypeUnavailable msiType = 4 msiTypeAppServiceV20190801 msiType = 5 msiTypeAzureArc msiType = 6 + msiTypeServiceFabric msiType = 7 ) // managedIdentityClient provides the base for authenticating in managed identity environments @@ -109,7 +112,7 @@ func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore Token string `json:"access_token,omitempty"` RefreshToken string `json:"refresh_token,omitempty"` ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid - ExpiresOn string `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string + ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string }{} if err := res.UnmarshalAsJSON(&value); err != nil { return nil, fmt.Errorf("internal AccessToken: %w", err) @@ -121,19 +124,26 @@ func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore } return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresIn)).UTC()}, nil } - if expiresOn, err := strconv.Atoi(value.ExpiresOn); err == nil { - return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresOn)).UTC()}, nil - } - // this is the case when expires_on is a time string - // this is the format of the string coming from the service - if expiresOn, err := time.Parse("1/2/2006 15:04:05 PM +00:00", value.ExpiresOn); err == nil { // the date string specified is for Windows OS - eo := expiresOn.UTC() - return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil - } else if expiresOn, err := time.Parse("1/2/2006 15:04:05 +00:00", value.ExpiresOn); err == nil { // the date string specified is for Linux OS - eo := expiresOn.UTC() - return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil - } else { - return nil, err + switch v := value.ExpiresOn.(type) { + case float64: + return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(v), 0).UTC()}, nil + case string: + if expiresOn, err := strconv.Atoi(v); err == nil { + return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil + } + // this is the case when expires_on is a time string + // this is the format of the string coming from the service + if expiresOn, err := time.Parse("1/2/2006 15:04:05 PM +00:00", v); err == nil { // the date string specified is for Windows OS + eo := expiresOn.UTC() + return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil + } else if expiresOn, err := time.Parse("1/2/2006 15:04:05 +00:00", v); err == nil { // the date string specified is for Linux OS + eo := expiresOn.UTC() + return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil + } else { + return nil, err + } + default: + return nil, &AuthenticationFailedError{msg: fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)} } } @@ -150,6 +160,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID return nil, &AuthenticationFailedError{inner: err, msg: "Failed to retreive secret key from the identity endpoint."} } return c.createAzureArcAuthRequest(ctx, key, scopes) + case msiTypeServiceFabric: + return c.createServiceFabricAuthRequest(ctx, clientID, scopes) case msiTypeCloudShell: return c.createCloudShellAuthRequest(ctx, clientID, scopes) default: @@ -213,6 +225,23 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, return request, nil } +func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) { + request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint) + if err != nil { + return nil, err + } + q := request.URL.Query() + request.Header.Set("Accept", "application/json") + request.Header.Set("Secret", os.Getenv(identityHeader)) + q.Add("api-version", serviceFabricAPIVersion) + q.Add("resource", strings.Join(scopes, " ")) + if id != "" { + q.Add(qpClientID, id) + } + request.URL.RawQuery = q.Encode() + return request, nil +} + func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resources []string) (string, error) { // create the request to retreive the secret key challenge provided by the HIMDS service request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint) @@ -296,6 +325,9 @@ func (c *managedIdentityClient) getMSIType() (msiType, error) { c.endpoint = endpointEnvVar if header := os.Getenv(identityHeader); header != "" { // if BOTH the env vars IDENTITY_ENDPOINT and IDENTITY_HEADER are set the msiType is AppService c.msiType = msiTypeAppServiceV20190801 + if thumbprint := os.Getenv(identityServerThumbprint); thumbprint != "" { // if IDENTITY_SERVER_THUMBPRINT is set the environment is Service Fabric + c.msiType = msiTypeServiceFabric + } } else if arcIMDS := os.Getenv(arcIMDSEndpoint); arcIMDS != "" { c.msiType = msiTypeAzureArc } else { diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 1cafdfe0ac2a..7cd3cf659d81 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -20,6 +20,7 @@ const ( appServiceWindowsSuccessResp = `{"access_token": "new_token", "expires_on": "9/14/2017 00:00:00 PM +00:00", "resource": "https://vault.azure.net", "token_type": "Bearer"}` appServiceLinuxSuccessResp = `{"access_token": "new_token", "expires_on": "09/14/2017 00:00:00 +00:00", "resource": "https://vault.azure.net", "token_type": "Bearer"}` expiresOnIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": "1560974028", "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}` + expiresOnNonStringIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": 1560974028, "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}` ) func clearEnvVars(envVars ...string) { @@ -294,7 +295,7 @@ func TestManagedIdentityCredential_CreateAppServiceAuthRequestV20170901(t *testi } } -func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) { +func TestManagedIdentityCredential_CreateAccessTokenExpiresOnStringInt(t *testing.T) { resetEnvironmentVarsForTest() srv, close := mock.NewServer() defer close() @@ -620,3 +621,44 @@ func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) { t.Fatalf("Unexpected resource ID in resource query param") } } + +func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) { + resetEnvironmentVarsForTest() + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(expiresOnNonStringIntResp))) + _ = os.Setenv("MSI_ENDPOINT", srv.URL()) + _ = os.Setenv("MSI_SECRET", "secret") + defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") + options := ManagedIdentityCredentialOptions{} + options.HTTPClient = srv + msiCred, err := NewManagedIdentityCredential(clientID, &options) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatalf("Received an error when attempting to retrieve a token") + } +} + +// adding an incorrect string value in expires_on +func TestManagedIdentityCredential_CreateAccessTokenExpiresOnFail(t *testing.T) { + resetEnvironmentVarsForTest() + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(`{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": "15609740s28", "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}`))) + _ = os.Setenv("MSI_ENDPOINT", srv.URL()) + _ = os.Setenv("MSI_SECRET", "secret") + defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") + options := ManagedIdentityCredentialOptions{} + options.HTTPClient = srv + msiCred, err := NewManagedIdentityCredential(clientID, &options) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err == nil { + t.Fatalf("expected to receive an error but received none") + } +}