diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index f49fd18bd23f..2d7590e374b1 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -2,6 +2,10 @@ ## 1.0.1 +- `ClientCertificateCredential` uses application and tenant IDs correctly +([#8315](https://github.com/Azure/azure-sdk-for-python/pull/8315)) + + ## 1.0.0 (2019-10-29) ### Breaking changes: - Async credentials now default to [`aiohttp`](https://pypi.org/project/aiohttp/) diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index 04b43a2b0618..e41994c29e3d 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -63,6 +63,10 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache + @property + def auth_url(self): + return self._auth_url + def get_cached_token(self, scopes): # type: (Iterable[str]) -> Optional[AccessToken] tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) diff --git a/sdk/identity/azure-identity/azure/identity/_base.py b/sdk/identity/azure-identity/azure/identity/_base.py index 3bc20d84b3b5..98850e5f7231 100644 --- a/sdk/identity/azure-identity/azure/identity/_base.py +++ b/sdk/identity/azure-identity/azure/identity/_base.py @@ -2,14 +2,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import abc import binascii from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.backends import default_backend from msal.oauth2cli import JwtSigner +import six -from ._constants import Endpoints +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore try: from typing import TYPE_CHECKING @@ -38,7 +43,7 @@ def __init__(self, tenant_id, client_id, secret, **kwargs): # pylint:disable=un super(ClientSecretCredentialBase, self).__init__() -class CertificateCredentialBase(object): +class CertificateCredentialBase(ABC): """Sans I/O base for certificate credentials""" def __init__(self, tenant_id, client_id, certificate_path, **kwargs): # pylint:disable=unused-argument @@ -58,16 +63,23 @@ def __init__(self, tenant_id, client_id, certificate_path, **kwargs): # pylint: cert = x509.load_pem_x509_certificate(pem_bytes, default_backend()) fingerprint = cert.fingerprint(hashes.SHA1()) - self._auth_url = Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id) + self._client = self._get_auth_client(tenant_id, **kwargs) self._client_id = client_id self._signer = JwtSigner(private_key, "RS256", sha1_thumbprint=binascii.hexlify(fingerprint)) def _get_request_data(self, *scopes): - assertion = self._signer.sign_assertion(audience=self._auth_url, issuer=self._client_id) + assertion = self._signer.sign_assertion(audience=self._client.auth_url, issuer=self._client_id) + if isinstance(assertion, six.binary_type): + assertion = assertion.decode("utf-8") + return { "client_assertion": assertion, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", "client_id": self._client_id, "grant_type": "client_credentials", - "scope": " ".join(scopes) + "scope": " ".join(scopes), } + + @abc.abstractmethod + def _get_auth_client(self, tenant_id, **kwargs): + pass diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py index 49cd6a2e92ca..10eae06c2b2a 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py @@ -63,11 +63,6 @@ class CertificateCredential(CertificateCredentialBase): defines authorities for other clouds. """ - def __init__(self, tenant_id, client_id, certificate_path, **kwargs): - # type: (str, str, str, **Any) -> None - self._client = AuthnClient(tenant=tenant_id, **kwargs) - super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs) - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken """Request an access token for `scopes`. @@ -83,3 +78,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument data = self._get_request_data(*scopes) token = self._client.request_token(scopes, form_data=data) return token + + def _get_auth_client(self, tenant_id, **kwargs): + return AuthnClient(tenant=tenant_id, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py index 398035567fad..d674bee2ec8d 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py @@ -57,10 +57,6 @@ class CertificateCredential(CertificateCredentialBase): defines authorities for other clouds. """ - def __init__(self, tenant_id: str, client_id: str, certificate_path: str, **kwargs: "Mapping[str, Any]") -> None: - super(CertificateCredential, self).__init__(tenant_id, client_id, certificate_path, **kwargs) - self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs) - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """Asynchronously request an access token for `scopes`. @@ -75,3 +71,6 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py data = self._get_request_data(*scopes) token = await self._client.request_token(scopes, form_data=data) return token # type: ignore + + def _get_auth_client(self, tenant_id, **kwargs): + return AsyncAuthnClient(tenant=tenant_id, **kwargs) diff --git a/sdk/identity/azure-identity/tests/certificate.pem b/sdk/identity/azure-identity/tests/certificate.pem new file mode 100644 index 000000000000..4b66bfa021a0 --- /dev/null +++ b/sdk/identity/azure-identity/tests/certificate.pem @@ -0,0 +1,49 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDL1hG+JYCfIPp3 +tlZ05J4pYIJ3Ckfs432bE3rYuWlR2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRY +OCI69s4+lP3DwR8uBCp9xyVkF8thXfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+Qf +oxAb6tx0kEc7V3ozBLWoIDJjfwJ3NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIr +Aa7pxHzo/Nd0U3e7z+DlBcJV7dY6TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bC +lG0u7unS7QOBMd6bOGkeL+Bc+n22slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpX +wj/Ek0F7AgMBAAECggEAblU3UWdXUcs2CCqIbcl52wfEVs8X05/n01MeAcWKvqYG +hvGcz7eLvhir5dQoXcF3VhybMrIe6C4WcBIiZSxGwxU+rwEP8YaLwX1UPfOrQM7s +sZTdFTLWfUslO3p7q300fdRA92iG9COMDZvkElh0cBvQksxs9sSr149l9vk+ymtC +uBhZtHG6Ki0BIMBNC9jGUqDuOatXl/dkK4tNjXrNJT7tVwzPaqnNALIWl6B+k9oQ +m1oNhSH2rvs9tw2ITXfIoIk9KdOMjQVUD43wKOaz0hNZhUsb1OFuls7UtRzaFcZH +rMd/M8DtA104QTTlHK+XS7r+nqdv7+ZyB+suTdM+oQKBgQDxCrJZU3hJ0eJ4VYhK +xGDfVGNpYxNkQ4CDB9fwRNbFr/Ck3kgzfE9QxTx1pJOolVmfuFmk9B86in4UNy91 +KdaqT79AU5RdOBXNN6tuMbLC0AVqe8sZq+1vWVVwbCstffxEMmyW1Ju/FLYPl2Zp +e5P96dBh5B3mXrQtpDJ0RkxxaQKBgQDYfE6tQQnQSs2ewD6ae8Mu6j8ueDlVoZ37 +vze1QdBasR26xu2H8XBt3u41zc524BwQsB1GE1tnC8ZylrqwVEayK4FesSQRCO6o +yK8QSdb06I5J4TaN+TppCDPLzstOh0Dmxp+iFUGoErb7AEOLAJ/VebhF9kBZObL/ +HYy4Es+bQwKBgHW/4vYuB3IQXNCp/+V+X1BZ+iJOaves3gekekF+b2itFSKFD8JO +9LQhVfKmTheptdmHhgtF0keXxhV8C+vxX1Ndl7EF41FSh5vzmQRAtPHkCvFEviex +TFD70/gSb1lO1UA/Xbqk69yBcprVPAtFejss0EYx2MVj+CLftmIEwW0ZAoGBAIMG +EVQ45eikLXjkn78+Iq7VZbIJX6IdNBH29I+GqsUJJ5Yw6fh6P3KwF3qG+mvmTfYn +sUAFXS+r58rYwVsRVsxlGmKmUc7hmhibhaEVH72QtvWuEiexbRG+viKfIVuA7t39 +3wXpWZiQ4yBdU4Pgt9wrVEU7ukyGaHiReOa7s90jAoGAJc0K7smn98YutQQ+g2ur +ybfnsl0YdsksaP2S2zvZUmNevKPrgnaIDDabOlhYYga+AK1G3FQ7/nefUgiIg1Nd +kr+T6Q4osS3xHB6Az9p/jaF4R2KaWN2nNVCn7ecsmPxDdM7k1vLxaT26vwO9OP5f +YU/5CeIzrfA5nQyPZkOXZBk= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUF2VIP4+AnEtb52KTCHbo4+fESfswDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xOTEwMzAyMjQ2MjBaFw0yMjA4 +MTkyMjQ2MjBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDL1hG+JYCfIPp3tlZ05J4pYIJ3Ckfs432bE3rYuWlR +2w9KqdjWkKxuAxpjJ+T+uoqVaT3BFMfi4ZRYOCI69s4+lP3DwR8uBCp9xyVkF8th +XfS3iui0liGDviVBoBJJWvjDFU8a/Hseg+QfoxAb6tx0kEc7V3ozBLWoIDJjfwJ3 +NdsLZGVtAC34qCWeEIvS97CDA4g3Kc6hYJIrAa7pxHzo/Nd0U3e7z+DlBcJV7dY6 +TZUyjBVTpzppWe+XQEOfKsjkDNykHEC1C1bClG0u7unS7QOBMd6bOGkeL+Bc+n22 +slTzs5amsbDLNuobSaUsFt9vgD5jRD6FwhpXwj/Ek0F7AgMBAAGjUzBRMB0GA1Ud +DgQWBBT6Mf9uXFB67bY2PeW3GCTKfkO7vDAfBgNVHSMEGDAWgBT6Mf9uXFB67bY2 +PeW3GCTKfkO7vDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCZ +1+kTISX85v9/ag7glavaPFUYsOSOOofl8gSzov7L01YL+srq7tXdvZmWrjQ/dnOY +h18rp9rb24vwIYxNioNG/M2cW1jBJwEGsDPOwdPV1VPcRmmUJW9kY130gRHBCd/N +qB7dIkcQnpNsxPIIWI+sRQp73U0ijhOByDnCNHLHon6vbfFTwkO1XggmV5BdZ3uQ +JNJyckILyNzlhmf6zhonMp4lVzkgxWsAm2vgdawd6dmBa+7Avb2QK9s+IdUSutFh +DgW2L12Obgh12Y4sf1iKQXA0RbZ2k+XQIz8EKZa7vJQY0ciYXSgB/BV3a96xX3cx +LIPL8Vam8Ytkopi3gsGA +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 694a92df922a..85eb7ff95cb2 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -2,7 +2,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 import json +import six try: from unittest import mock @@ -60,6 +62,14 @@ def validate_request(request, **kwargs): return mock.Mock(send=validate_request) +def urlsafeb64_decode(s): + if isinstance(s, six.text_type): + s = s.encode("ascii") + + padding_needed = 4 - len(s) % 4 + return base64.urlsafe_b64decode(s + b"=" * padding_needed) + + try: import asyncio diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py new file mode 100644 index 000000000000..7cf4c8e698bb --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -0,0 +1,60 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import json +import os + +from azure.identity import CertificateCredential +from six.moves.urllib_parse import urlparse +from helpers import urlsafeb64_decode, mock_response + +try: + from unittest.mock import Mock, patch +except ImportError: # python < 3.3 + from mock import Mock, patch # type: ignore + +CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") + + +def test_request_url(): + authority = "authority.com" + tenant_id = "expected_tenant" + access_token = "***" + + def validate_url(url): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant_id) + + def mock_send(request, **kwargs): + validate_url(request.url) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + cred = CertificateCredential(tenant_id, "client_id", CERT_PATH, transport=Mock(send=mock_send), authority=authority) + token = cred.get_token("scope") + assert token.token == access_token + + +def test_request_body(): + access_token = "***" + authority = "authority.com" + tenant_id = "tenant" + + def validate_url(url): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant_id) + + def mock_send(request, **kwargs): + jwt = request.body["client_assertion"] + header, payload, signature = (urlsafeb64_decode(s) for s in jwt.split(".")) + claims = json.loads(payload.decode("utf-8")) + validate_url(claims["aud"]) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + cred = CertificateCredential(tenant_id, "client_id", CERT_PATH, transport=Mock(send=mock_send), authority=authority) + token = cred.get_token("scope") + assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py new file mode 100644 index 000000000000..3ef0b71da665 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -0,0 +1,59 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import json +import os +from unittest.mock import Mock, patch +from urllib.parse import urlparse + +from azure.identity.aio import CertificateCredential +from helpers import urlsafeb64_decode, mock_response +import pytest + +CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") + + +@pytest.mark.asyncio +async def test_request_url(): + authority = "authority.com" + tenant_id = "expected_tenant" + access_token = "***" + + def validate_url(url): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant_id) + + async def mock_send(request, **kwargs): + validate_url(request.url) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + cred = CertificateCredential(tenant_id, "client_id", CERT_PATH, transport=Mock(send=mock_send), authority=authority) + token = await cred.get_token("scope") + assert token.token == access_token + + +@pytest.mark.asyncio +async def test_request_body(): + access_token = "***" + authority = "authority.com" + tenant_id = "tenant" + + def validate_url(url): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant_id) + + async def mock_send(request, **kwargs): + jwt = request.body["client_assertion"] + header, payload, signature = (urlsafeb64_decode(s) for s in jwt.split(".")) + claims = json.loads(payload.decode("utf-8")) + validate_url(claims["aud"]) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + cred = CertificateCredential(tenant_id, "client_id", CERT_PATH, transport=Mock(send=mock_send), authority=authority) + token = await cred.get_token("scope") + assert token.token == access_token