Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sdk/identity/azure-identity/HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Release History

## 1.0.0b4
### Fixes and improvements:
- `UsernamePasswordCredential` correctly handles environment configuration with
no tenant information (#7260)
- MSAL's user realm discovery requests are sent through credential
pipelines (#7260)

## 1.0.0b3 (2019-09-10)
### New features:
- `SharedTokenCacheCredential` authenticates with tokens stored in a local
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

try:
from typing import TYPE_CHECKING
except ImportError:
Expand Down Expand Up @@ -64,10 +59,11 @@ def _create_app(self, cls):
"""Creates an MSAL application, patching msal.authority to use an azure-core pipeline during tenant discovery"""

# MSAL application initializers use msal.authority to send AAD tenant discovery requests
with mock.patch("msal.authority.requests", self._adapter):
with self._adapter:
app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._authority)

# monkeypatch the app to replace requests.Session with MsalTransportAdapter
app.client.session.close()
app.client.session = self._adapter

return app
Expand Down Expand Up @@ -106,8 +102,9 @@ class PublicClientCredential(MsalCredential):

def __init__(self, **kwargs):
# type: (Any) -> None
tenant = kwargs.pop("tenant", None) or "organizations"
super(PublicClientCredential, self).__init__(
authority="https://login.microsoftonline.com/" + kwargs.pop("tenant", "organizations"), **kwargs
authority="https://login.microsoftonline.com/" + tenant, **kwargs
)

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy
from azure.core.pipeline.transport import HttpRequest, RequestsTransport

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, Mapping, Optional
from azure.core.pipeline import PipelineResponse

Expand All @@ -38,17 +43,30 @@ def json(self, **kwargs):

def raise_for_status(self):
# type: () -> None
raise ClientAuthenticationError("authentication failed", self._response)
if self.status_code >= 400:
raise ClientAuthenticationError("authentication failed", self._response)


class MsalTransportAdapter(object):
"""Wraps an azure-core pipeline with the shape of requests.Session"""
"""
Wraps an azure-core pipeline with the shape of requests.Session.

Used as a context manager, patches msal.authority to intercept calls to requests.
"""

def __init__(self, **kwargs):
# type: (Any) -> None
super(MsalTransportAdapter, self).__init__()
self._patch = mock.patch("msal.authority.requests", self)
self._pipeline = self._build_pipeline(**kwargs)

def __enter__(self):
self._patch.__enter__()
return self

def __exit__(self, *args):
self._patch.__exit__(*args)

@staticmethod
def _create_config(**kwargs):
# type: (Any) -> Configuration
Expand Down
7 changes: 4 additions & 3 deletions sdk/identity/azure-identity/azure/identity/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

if not result:
# cache miss -> request a new token
result = app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=scopes
)
with self._adapter:
result = app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=scopes
)

if "access_token" not in result:
raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description")))
Expand Down
50 changes: 48 additions & 2 deletions sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,13 @@ def test_interactive_credential_timeout():
def test_username_password_credential():
expected_token = "access-token"
transport = validating_transport(
requests=[Request()] * 2, # not validating requests because they're formed by MSAL
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# expecting tenant discovery then a token request
# tenant discovery
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
# user realm discovery, interests MSAL only when the response body contains account_type == "Federated"
mock_response(json_payload={}),
# token request
mock_response(
json_payload={
"access_token": expected_token,
Expand All @@ -436,3 +439,46 @@ def test_username_password_credential():

token = credential.get_token("scope")
assert token.token == expected_token


def test_username_password_environment_credential(monkeypatch):
client_id = "fake-client-id"
username = "foo@bar.com"
password = "password"
expected_token = "***"

create_transport = functools.partial(
validating_transport,
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# tenant discovery
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
# user realm discovery, interests MSAL only when the response body contains account_type == "Federated"
mock_response(json_payload={}),
# token request
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"token_type": "Bearer",
"ext_expires_in": 42,
}
),
],
)

monkeypatch.setenv(EnvironmentVariables.AZURE_CLIENT_ID, client_id)
monkeypatch.setenv(EnvironmentVariables.AZURE_USERNAME, username)
monkeypatch.setenv(EnvironmentVariables.AZURE_PASSWORD, password)

token = EnvironmentCredential(transport=create_transport()).get_token("scope")

# not validating expires_on because doing so requires monkeypatching time, and this is tested elsewhere
assert token.token == expected_token

# now with a tenant id
monkeypatch.setenv(EnvironmentVariables.AZURE_TENANT_ID, "tenant_id")

token = EnvironmentCredential(transport=create_transport()).get_token("scope")

assert token.token == expected_token