diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index df64badcb396..80e0ff93b793 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -1,5 +1,12 @@ # Release History +## 1.0.0b4 +### Fixes and improvements: +- `UsernamePasswordCredential` correctly handles environment configuration with +no tenant information (#7260) +- MSAL's user realm discovery requests are sent through credential +pipelines (#7260) + ## 1.0.0b3 (2019-09-10) ### New features: - `SharedTokenCacheCredential` authenticates with tokens stored in a local diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index bae20de6d3e8..29e3a8f1cd62 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -20,11 +20,6 @@ except AttributeError: # Python 2.7, abc exists, but not ABC ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock # type: ignore - try: from typing import TYPE_CHECKING except ImportError: @@ -64,10 +59,11 @@ def _create_app(self, cls): """Creates an MSAL application, patching msal.authority to use an azure-core pipeline during tenant discovery""" # MSAL application initializers use msal.authority to send AAD tenant discovery requests - with mock.patch("msal.authority.requests", self._adapter): + with self._adapter: app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._authority) # monkeypatch the app to replace requests.Session with MsalTransportAdapter + app.client.session.close() app.client.session = self._adapter return app @@ -106,8 +102,9 @@ class PublicClientCredential(MsalCredential): def __init__(self, **kwargs): # type: (Any) -> None + tenant = kwargs.pop("tenant", None) or "organizations" super(PublicClientCredential, self).__init__( - authority="https://login.microsoftonline.com/" + kwargs.pop("tenant", "organizations"), **kwargs + authority="https://login.microsoftonline.com/" + tenant, **kwargs ) @abc.abstractmethod diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py index b6b8e96f46af..9cf777424ede 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py @@ -12,13 +12,18 @@ from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy from azure.core.pipeline.transport import HttpRequest, RequestsTransport +try: + from unittest import mock +except ImportError: # python < 3.3 + import mock # type: ignore + try: from typing import TYPE_CHECKING except ImportError: TYPE_CHECKING = False if TYPE_CHECKING: - # pylint:disable=unused-import + # pylint:disable=unused-import,ungrouped-imports from typing import Any, Dict, Mapping, Optional from azure.core.pipeline import PipelineResponse @@ -38,17 +43,30 @@ def json(self, **kwargs): def raise_for_status(self): # type: () -> None - raise ClientAuthenticationError("authentication failed", self._response) + if self.status_code >= 400: + raise ClientAuthenticationError("authentication failed", self._response) class MsalTransportAdapter(object): - """Wraps an azure-core pipeline with the shape of requests.Session""" + """ + Wraps an azure-core pipeline with the shape of requests.Session. + + Used as a context manager, patches msal.authority to intercept calls to requests. + """ def __init__(self, **kwargs): # type: (Any) -> None super(MsalTransportAdapter, self).__init__() + self._patch = mock.patch("msal.authority.requests", self) self._pipeline = self._build_pipeline(**kwargs) + def __enter__(self): + self._patch.__enter__() + return self + + def __exit__(self, *args): + self._patch.__exit__(*args) + @staticmethod def _create_config(**kwargs): # type: (Any) -> Configuration diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 091b6472c676..2bff4e17102f 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -433,9 +433,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if not result: # cache miss -> request a new token - result = app.acquire_token_by_username_password( - username=self._username, password=self._password, scopes=scopes - ) + with self._adapter: + result = app.acquire_token_by_username_password( + username=self._username, password=self._password, scopes=scopes + ) if "access_token" not in result: raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description"))) diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 799d50476523..2d7c1208a3ba 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -411,10 +411,13 @@ def test_interactive_credential_timeout(): def test_username_password_credential(): expected_token = "access-token" transport = validating_transport( - requests=[Request()] * 2, # not validating requests because they're formed by MSAL + requests=[Request()] * 3, # not validating requests because they're formed by MSAL responses=[ - # expecting tenant discovery then a token request + # tenant discovery mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + # user realm discovery, interests MSAL only when the response body contains account_type == "Federated" + mock_response(json_payload={}), + # token request mock_response( json_payload={ "access_token": expected_token, @@ -436,3 +439,46 @@ def test_username_password_credential(): token = credential.get_token("scope") assert token.token == expected_token + + +def test_username_password_environment_credential(monkeypatch): + client_id = "fake-client-id" + username = "foo@bar.com" + password = "password" + expected_token = "***" + + create_transport = functools.partial( + validating_transport, + requests=[Request()] * 3, # not validating requests because they're formed by MSAL + responses=[ + # tenant discovery + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + # user realm discovery, interests MSAL only when the response body contains account_type == "Federated" + mock_response(json_payload={}), + # token request + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 42, + "token_type": "Bearer", + "ext_expires_in": 42, + } + ), + ], + ) + + monkeypatch.setenv(EnvironmentVariables.AZURE_CLIENT_ID, client_id) + monkeypatch.setenv(EnvironmentVariables.AZURE_USERNAME, username) + monkeypatch.setenv(EnvironmentVariables.AZURE_PASSWORD, password) + + token = EnvironmentCredential(transport=create_transport()).get_token("scope") + + # not validating expires_on because doing so requires monkeypatching time, and this is tested elsewhere + assert token.token == expected_token + + # now with a tenant id + monkeypatch.setenv(EnvironmentVariables.AZURE_TENANT_ID, "tenant_id") + + token = EnvironmentCredential(transport=create_transport()).get_token("scope") + + assert token.token == expected_token