Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .msal_credentials import ConfidentialClientCredential
from .msal_transport_adapter import MsalTransportResponse, MsalTransportAdapter
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Credentials wrapping MSAL applications and delegating token acquisition and caching to them.
This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests.
"""

import time

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Mapping, Optional, Union

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
import msal

from .msal_transport_adapter import MsalTransportAdapter


class MsalCredential(object):
"""Base class for credentials wrapping MSAL applications"""

def __init__(self, client_id, authority, app_class, client_credential=None, **kwargs):
# type: (str, str, msal.ClientApplication, Optional[Union[str, Mapping[str, str]]], Any) -> None
self._authority = authority
self._client_credential = client_credential
self._client_id = client_id

self._adapter = kwargs.pop("msal_adapter", None) or MsalTransportAdapter(**kwargs)

# postpone creating the wrapped application because its initializer uses the network
self._app_class = app_class
self._msal_app = None # type: Optional[msal.ClientApplication]

@property
def _app(self):
# type: () -> msal.ClientApplication
"""The wrapped MSAL application"""

if not self._msal_app:
# MSAL application initializers use msal.authority to send AAD tenant discovery requests
with mock.patch("msal.authority.requests", self._adapter):
app = self._app_class(
client_id=self._client_id, client_credential=self._client_credential, authority=self._authority
)

# monkeypatch the app to replace requests.Session with MsalTransportAdapter
app.client.session = self._adapter
self._msal_app = app

return self._msal_app


class ConfidentialClientCredential(MsalCredential):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

def __init__(self, **kwargs):
# type: (Any) -> None
super(ConfidentialClientCredential, self).__init__(app_class=msal.ConfidentialClientApplication, **kwargs)

def get_token(self, *scopes):
# type: (str) -> AccessToken

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

# First try to get a cached access token or if a refresh token is cached, redeem it for an access token.
# Failing that, acquire a new token.
app = self._app # type: msal.ConfidentialClientApplication
result = app.acquire_token_silent(scopes, account=None) or app.acquire_token_for_client(scopes)

if "access_token" not in result:
raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description")))

return AccessToken(result["access_token"], now + int(result["expires_in"]))
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Adapter to substitute an azure-core pipeline for Requests in MSAL application token acquisition methods."""

import json

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional
from azure.core.pipeline import PipelineResponse

from azure.core.configuration import Configuration
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
from azure.core.pipeline.transport import HttpRequest, RequestsTransport


class MsalTransportResponse:
"""Wraps an azure-core PipelineResponse with the shape of requests.Response"""

def __init__(self, pipeline_response):
# type: (PipelineResponse) -> None
self._response = pipeline_response.http_response
self.status_code = self._response.status_code
self.text = self._response.text()

def json(self, **kwargs):
# type: (Any) -> Mapping[str, Any]
return json.loads(self.text, **kwargs)

def raise_for_status(self):
# type: () -> None
raise ClientAuthenticationError("authentication failed", self._response)


class MsalTransportAdapter(object):
"""Wraps an azure-core pipeline with the shape of requests.Session"""

def __init__(self, **kwargs):
# type: (Any) -> None
super(MsalTransportAdapter, self).__init__()
self._pipeline = self._build_pipeline(**kwargs)

@staticmethod
def create_config(**kwargs):
# type: (Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
return config

def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
config = config or self.create_config(**kwargs)
policies = policies or [ContentDecodePolicy(), config.retry_policy, config.logging_policy]
if not transport:
transport = RequestsTransport(configuration=config)
return Pipeline(transport=transport, policies=policies)

def get(self, url, headers=None, params=None, timeout=None, verify=None, **kwargs):
# type: (str, Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse
request = HttpRequest("GET", url, headers=headers)
if params:
request.format_parameters(params)
response = self._pipeline.run(
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs
)
return MsalTransportResponse(response)

def post(self, url, data=None, headers=None, params=None, timeout=None, verify=None, **kwargs):
# type: (str, Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse
request = HttpRequest("POST", url, headers=headers)
if params:
request.format_parameters(params)
if data:
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(data)
response = self._pipeline.run(
request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs
)
return MsalTransportResponse(response)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ._authn_client import AsyncAuthnClient
from ..constants import Endpoints, EnvironmentVariables
from .._internal import _ManagedIdentityBase
from .._managed_identity import _ManagedIdentityBase


class _AsyncManagedIdentityBase(_ManagedIdentityBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, AsyncRetryPolicy

from ._authn_client import AsyncAuthnClient
from ._internal import ImdsCredential, MsiCredential
from ._managed_identity import ImdsCredential, MsiCredential
from .._base import ClientSecretCredentialBase, CertificateCredentialBase
from ..constants import Endpoints, EnvironmentVariables
from ..credentials import ChainedTokenCredential
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/azure/identity/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ._authn_client import AuthnClient
from ._base import ClientSecretCredentialBase, CertificateCredentialBase
from ._internal import ImdsCredential, MsiCredential
from ._managed_identity import ImdsCredential, MsiCredential
from .constants import Endpoints, EnvironmentVariables

try:
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ManagedIdentityCredential,
ChainedTokenCredential,
)
from azure.identity._internal import ImdsCredential
from azure.identity._managed_identity import ImdsCredential
from azure.identity.constants import EnvironmentVariables

from helpers import mock_response, Request, validating_transport
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/test_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
EnvironmentCredential,
ManagedIdentityCredential,
)
from azure.identity.aio._internal import ImdsCredential
from azure.identity.aio._managed_identity import ImdsCredential
from azure.identity.constants import EnvironmentVariables

from helpers import mock_response, Request, async_validating_transport
Expand Down
22 changes: 13 additions & 9 deletions sdk/identity/azure-identity/tests/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore

from azure.identity import DefaultAzureCredential, CertificateCredential, ClientSecretCredential
from azure.identity.constants import EnvironmentVariables
import pytest
from azure.identity._internal import ConfidentialClientCredential

ARM_SCOPE = "https://management.azure.com/.default"

Expand Down Expand Up @@ -46,3 +38,15 @@ def test_default_credential(live_identity_settings):
assert token
assert token.token
assert token.expires_on


def test_confidential_client_credential(live_identity_settings):
credential = ConfidentialClientCredential(
client_id=live_identity_settings["client_id"],
client_credential=live_identity_settings["client_secret"],
authority="https://login.microsoftonline.com/" + live_identity_settings["tenant_id"],
)
token = credential.get_token(ARM_SCOPE)
assert token
assert token.token
assert token.expires_on