diff --git a/sdk/core/azure-core/azure/core/credentials_async.py b/sdk/core/azure-core/azure/core/credentials_async.py new file mode 100644 index 000000000000..b42e717f4790 --- /dev/null +++ b/sdk/core/azure-core/azure/core/credentials_async.py @@ -0,0 +1,23 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from typing_extensions import Protocol + from .credentials import AccessToken + + class AsyncTokenCredential(Protocol): + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + pass + + async def close(self) -> None: + pass + + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + pass diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 11e6c8cf99fe..45245c252d87 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -29,11 +29,10 @@ class _BearerTokenCredentialPolicyBase(object): :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (TokenCredential, *str, Mapping[str, Any]) -> None + def __init__(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (*str, **Any) -> None super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes - self._credential = credential self._token = None # type: Optional[AccessToken] @staticmethod @@ -69,6 +68,11 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ + def __init__(self, credential, *scopes, **kwargs): + # type: (TokenCredential, *str, **Any) -> None + self._credential = credential + super(BearerTokenCredentialPolicy, self).__init__(*scopes, **kwargs) + def on_request(self, request): # type: (PipelineRequest) -> None """Adds a bearer token Authorization header to request and sends request to next policy. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 1c6e220d092a..1edac216f2d0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -4,26 +4,33 @@ # license information. # ------------------------------------------------------------------------- import threading +from typing import TYPE_CHECKING -from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline import PipelineRequest + class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy): # pylint:disable=too-few-public-methods """Adds a bearer token Authorization header to requests. :param credential: The credential. - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenCredential :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential, *scopes, **kwargs): - super().__init__(credential, *scopes, **kwargs) + def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None: + self._credential = credential self._lock = threading.Lock() + super().__init__(*scopes, **kwargs) - async def on_request(self, request: PipelineRequest): + async def on_request(self, request: "PipelineRequest"): """Adds a bearer token Authorization header to request and sends request to next policy. :param request: The pipeline request object to be modified. diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index 84bc1fb147c4..35122e3b983a 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -4,6 +4,8 @@ - All credential pipelines include `ProxyPolicy` ([#8945](https://github.com/Azure/azure-sdk-for-python/pull/8945)) +- Async credentials are async context managers and have an async `close` method +([#9090](https://github.com/Azure/azure-sdk-for-python/pull/9090)) ## 1.1.0 (2019-11-27) diff --git a/sdk/identity/azure-identity/README.md b/sdk/identity/azure-identity/README.md index df2568b237c6..e44003aaeb44 100644 --- a/sdk/identity/azure-identity/README.md +++ b/sdk/identity/azure-identity/README.md @@ -213,6 +213,24 @@ async transport, such as [aiohttp](https://pypi.org/project/aiohttp/). See [azure-core documentation](../../core/azure-core/README.md#transport) for more information. +Async credentials should be closed when they're no longer needed. Each async +credential is an async context manager and defines an async `close` method. For +example: + +```py +from azure.identity.aio import DefaultAzureCredential + +# call close when the credential is no longer needed +credential = DefaultAzureCredential() +... +await credential.close() + +# alternatively, use the credential as an async context manager +credential = DefaultAzureCredential() +async with credential: + ... +``` + This example demonstrates authenticating the asynchronous `SecretClient` from [azure-keyvault-secrets][azure_keyvault_secrets] with an asynchronous credential. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 3056bd21f4dd..6c97ac8b6865 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -15,6 +15,19 @@ from azure.core.credentials import AccessToken, TokenCredential +def _get_error_message(history): + attempts = [] + for credential, error in history: + if error: + attempts.append("{}: {}".format(credential.__class__.__name__, error)) + else: + attempts.append(credential.__class__.__name__) + return """No credential in this chain provided a token. +Attempted credentials:\n\t{}""".format( + "\n\t".join(attempts) + ) + + class ChainedTokenCredential(object): """A sequence of credentials that is itself a credential. @@ -48,16 +61,5 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument history.append((credential, ex.message)) except Exception as ex: # pylint: disable=broad-except history.append((credential, str(ex))) - error_message = self._get_error_message(history) + error_message = _get_error_message(history) raise ClientAuthenticationError(message=error_message) - - @staticmethod - def _get_error_message(history): - attempts = [] - for credential, error in history: - if error: - attempts.append("{}: {}".format(credential.__class__.__name__, error)) - else: - attempts.append(credential.__class__.__name__) - return """No credential in this chain provided a token. -Attempted credentials:\n\t{}""".format("\n\t".join(attempts)) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index 47a4a90ebca7..9cfe13bd9498 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -54,6 +54,16 @@ def __init__( self._pipeline = AsyncPipeline(transport=transport, policies=policies) super().__init__(**kwargs) + async def __aenter__(self): + await self._pipeline.__aenter__() + return self + + async def __aexit__(self, *args): + await self.close() + + async def close(self) -> None: + await self._pipeline.__aexit__() + async def request_token( self, scopes: "Iterable[str]", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/base.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/base.py new file mode 100644 index 000000000000..3dbc1a3a7a68 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/base.py @@ -0,0 +1,21 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc + + +class AsyncCredentialBase(abc.ABC): + @abc.abstractmethod + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + await self.close() + + @abc.abstractmethod + async def get_token(self, *scopes, **kwargs): + pass diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index 04a94ba2b96c..155d02860763 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -2,27 +2,40 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import asyncio from typing import TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError -from ... import ChainedTokenCredential as SyncChainedTokenCredential +from .base import AsyncCredentialBase +from ..._credentials.chained import _get_error_message if TYPE_CHECKING: from typing import Any from azure.core.credentials import AccessToken + from azure.core.credentials_async import AsyncTokenCredential -class ChainedTokenCredential(SyncChainedTokenCredential): +class ChainedTokenCredential(AsyncCredentialBase): """A sequence of credentials that is itself a credential. Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first valid token received. :param credentials: credential instances to form the chain - :type credentials: :class:`azure.core.credentials.TokenCredential` + :type credentials: :class:`azure.core.credentials.AsyncTokenCredential` """ - async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument + def __init__(self, *credentials: "AsyncTokenCredential") -> None: + if not credentials: + raise ValueError("at least one credential is required") + self.credentials = credentials + + async def close(self): + """Close the transport sessions of all credentials in the chain.""" + + await asyncio.gather(*(credential.close() for credential in self.credentials)) + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": """Asynchronously request a token from each credential, in order, returning the first token received. If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` @@ -41,5 +54,5 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py history.append((credential, ex.message)) except Exception as ex: # pylint: disable=broad-except history.append((credential, str(ex))) - error_message = self._get_error_message(history) + error_message = _get_error_message(history) raise ClientAuthenticationError(message=error_message) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py index d674bee2ec8d..79ea22931960 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py @@ -4,6 +4,7 @@ # ------------------------------------ from typing import TYPE_CHECKING +from .base import AsyncCredentialBase from .._authn_client import AsyncAuthnClient from ..._base import ClientSecretCredentialBase, CertificateCredentialBase @@ -12,7 +13,7 @@ from azure.core.credentials import AccessToken -class ClientSecretCredential(ClientSecretCredentialBase): +class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase): """Authenticates as a service principal using a client ID and client secret. :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. @@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs: super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs) self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs) + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def close(self): + """Close the credential's transport session.""" + + await self._client.__aexit__() + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """Asynchronously request an access token for `scopes`. @@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py return token # type: ignore -class CertificateCredential(CertificateCredentialBase): +class CertificateCredential(CertificateCredentialBase, AsyncCredentialBase): """Authenticates as a service principal using a certificate. :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. @@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase): defines authorities for other clouds. """ + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def close(self): + """Close the credential's transport session.""" + + await self._client.__aexit__() + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """Asynchronously request an access token for `scopes`. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index 56086433215d..33f420257a71 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -8,13 +8,14 @@ from azure.core.exceptions import ClientAuthenticationError from ..._constants import EnvironmentVariables from .client_credential import CertificateCredential, ClientSecretCredential +from .base import AsyncCredentialBase if TYPE_CHECKING: from typing import Any, Optional, Union from azure.core.credentials import AccessToken -class EnvironmentCredential: +class EnvironmentCredential(AsyncCredentialBase): """A credential configured by environment variables. This credential is capable of authenticating as a service principal using a client secret or a certificate, or as @@ -50,6 +51,17 @@ def __init__(self, **kwargs: "Any") -> None: **kwargs ) + async def __aenter__(self): + if self._credential: + await self._credential.__aenter__() + return self + + async def close(self): + """Close the credential's transport session.""" + + if self._credential: + await self._credential.__aexit__() + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": """Asynchronously request an access token for `scopes`. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index dff552a2e560..697b5c07086e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import abc import os from typing import TYPE_CHECKING @@ -10,6 +11,7 @@ from azure.core.pipeline.policies import AsyncRetryPolicy from azure.identity._credentials.managed_identity import _ManagedIdentityBase +from .base import AsyncCredentialBase from .._authn_client import AsyncAuthnClient from ..._constants import Endpoints, EnvironmentVariables @@ -37,6 +39,15 @@ def __new__(cls, *args, **kwargs): def __init__(self, **kwargs: "Any") -> None: pass + async def __aenter__(self): + pass + + async def __aexit__(self, *args): + pass + + async def close(self): + """Close the credential's transport session.""" + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """Asynchronously request an access token for `scopes`. @@ -49,10 +60,23 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py return AccessToken() -class _AsyncManagedIdentityBase(_ManagedIdentityBase): +class _AsyncManagedIdentityBase(_ManagedIdentityBase, AsyncCredentialBase): def __init__(self, endpoint: str, **kwargs: "Any") -> None: super().__init__(endpoint=endpoint, client_cls=AsyncAuthnClient, **kwargs) + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def close(self): + """Close the credential's transport session.""" + + await self._client.__aexit__() + + @abc.abstractmethod + async def get_token(self, *scopes, **kwargs): + pass + @staticmethod def _create_config(**kwargs: "Any") -> "Configuration": """Build a default configuration for the credential's HTTP pipeline.""" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index bc5c2f799dac..46c518df6541 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -9,6 +9,7 @@ from ..._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase from .._internal.aad_client import AadClient from .._internal.exception_wrapper import wrap_exceptions +from .base import AsyncCredentialBase if TYPE_CHECKING: from typing import Any @@ -16,7 +17,7 @@ from ..._internal.aad_client import AadClientBase -class SharedTokenCacheCredential(SharedTokenCacheBase): +class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase): """Authenticates using tokens in the local cache shared between Microsoft applications. :param str username: @@ -24,6 +25,17 @@ class SharedTokenCacheCredential(SharedTokenCacheBase): may contain tokens for multiple identities. """ + async def __aenter__(self): + if self._client: + await self._client.__aenter__() + return self + + async def close(self): + """Close the credential's transport session.""" + + if self._client: + await self._client.__aexit__() + @wrap_exceptions async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """Get an access token for `scopes` from the shared cache. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 853af81ec319..ceb96ba87d21 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -19,8 +19,19 @@ class AadClient(AadClientBase): - # pylint:disable=arguments-differ + async def __aenter__(self): + await self._client.session.__aenter__() + return self + + async def __aexit__(self, *args): + await self.close() + + async def close(self) -> None: + """Close the client's transport session.""" + await self._client.session.__aexit__() + + # pylint:disable=arguments-differ def obtain_token_by_authorization_code( self, *args: "Any", loop: "asyncio.AbstractEventLoop" = None, **kwargs: "Any" ) -> "AccessToken": diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py index 9fc38ad38727..1f4d12e151cc 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py @@ -5,7 +5,6 @@ """Adapter to substitute an async azure-core pipeline for Requests in MSAL application token acquisition methods.""" import asyncio -import atexit from typing import TYPE_CHECKING from azure.core.configuration import Configuration @@ -51,21 +50,18 @@ def __init__( HttpLoggingPolicy(**kwargs), ] self._transport = transport or AioHttpTransport(configuration=config) - atexit.register(self._close_transport_session) # prevent aiohttp warnings self._pipeline = AsyncPipeline(transport=self._transport, policies=policies) - def _close_transport_session(self) -> None: - """If transport has a 'close' method, invoke it.""" + async def __aenter__(self): + await self._pipeline.__aenter__() + return self - close = getattr(self._transport, "close", None) - if not callable(close): - return + async def __aexit__(self, *args): + await self.close() - if asyncio.iscoroutinefunction(close): - # we expect no loop is running because this method should be called only when the interpreter is exiting - asyncio.new_event_loop().run_until_complete(close()) - else: - close() + async def close(self): + """Close the adapter's transport session.""" + await self._pipeline.__aexit__() def get( self, diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 02f4cfd1f819..5e5f19dbf103 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -6,7 +6,6 @@ import json import time -from azure.core.pipeline.policies import SansIOHTTPPolicy import six try: @@ -142,15 +141,3 @@ def urlsafeb64_decode(s): padding_needed = 4 - len(s) % 4 return base64.urlsafe_b64decode(s + b"=" * padding_needed) - - -try: - import asyncio - - def async_validating_transport(requests, responses): - sync_transport = validating_transport(requests, responses) - return mock.Mock(send=asyncio.coroutine(sync_transport.send)) - - -except ImportError: - pass diff --git a/sdk/identity/azure-identity/tests/helpers_async.py b/sdk/identity/azure-identity/tests/helpers_async.py new file mode 100644 index 000000000000..38fac1d36395 --- /dev/null +++ b/sdk/identity/azure-identity/tests/helpers_async.py @@ -0,0 +1,49 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +import functools +import sys +from unittest import mock + +from helpers import validating_transport + + +def get_completed_future(result=None): + future = asyncio.Future() + future.set_result(result) + return future + + +def wrap_in_future(fn): + """Return a completed Future whose result is the return of fn. + + Added to simplify using unittest.Mock in async code. Python 3.8's AsyncMock would be preferable. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + result = fn(*args, **kwargs) + return get_completed_future(result) + + return wrapper + + +class AsyncMockTransport(mock.MagicMock): + """Mock with do-nothing aenter/exit for mocking async transport. + + This is unnecessary on 3.8+, where MagicMocks implement aenter/exit. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if sys.version_info < (3, 8): + self.__aenter__ = mock.Mock(return_value=get_completed_future()) + self.__aexit__ = mock.Mock(return_value=get_completed_future()) + + +def async_validating_transport(requests, responses): + sync_transport = validating_transport(requests, responses) + return AsyncMockTransport(send=wrap_in_future(sync_transport.send)) diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index 84c02f18ba32..1f1d9719169c 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -12,7 +12,8 @@ from azure.identity.aio import AuthorizationCodeCredential import pytest -from helpers import async_validating_transport, build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request +from helpers_async import async_validating_transport, wrap_in_future @pytest.mark.asyncio @@ -55,7 +56,7 @@ async def test_auth_code_credential(): mock_client = Mock(spec=object) obtain_by_auth_code = Mock(return_value=expected_token) - mock_client.obtain_token_by_authorization_code = asyncio.coroutine(obtain_by_auth_code) + mock_client.obtain_token_by_authorization_code = wrap_in_future(obtain_by_auth_code) credential = AuthorizationCodeCredential( client_id=client_id, @@ -81,7 +82,7 @@ async def test_auth_code_credential(): # no auth code, no cached token -> credential should use refresh token mock_client.get_cached_access_token = lambda *_: None mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] - mock_client.obtain_token_by_refresh_token = asyncio.coroutine(lambda *_, **__: expected_token) + mock_client.obtain_token_by_refresh_token = wrap_in_future(lambda *_, **__: expected_token) token = await credential.get_token("scope") assert token is expected_token diff --git a/sdk/identity/azure-identity/tests/test_authn_client_async.py b/sdk/identity/azure-identity/tests/test_authn_client_async.py index 06e63dc9dba9..b2982c2190b5 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client_async.py +++ b/sdk/identity/azure-identity/tests/test_authn_client_async.py @@ -10,6 +10,7 @@ from azure.identity.aio._authn_client import AsyncAuthnClient from helpers import mock_response +from helpers_async import wrap_in_future @pytest.mark.asyncio @@ -24,5 +25,5 @@ def mock_send(request, **kwargs): assert path.startswith("/" + tenant) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=asyncio.coroutine(mock_send)), authority=authority) + client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)), authority=authority) await client.request_token(("scope",)) diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 174f62f99baf..d14e929693b3 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -11,15 +11,38 @@ from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import CertificateCredential -from helpers import async_validating_transport, build_aad_response, urlsafeb64_decode, mock_response, Request -from test_certificate_credential import validate_jwt - import pytest +from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request +from helpers_async import async_validating_transport, AsyncMockTransport +from test_certificate_credential import validate_jwt + CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + credential = CertificateCredential("tenant-id", "client-id", CERT_PATH, transport=transport) + + await credential.close() + + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + credential = CertificateCredential("tenant-id", "client-id", CERT_PATH, transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 1 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + @pytest.mark.asyncio async def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) diff --git a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py new file mode 100644 index 000000000000..af66cd31bc48 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py @@ -0,0 +1,32 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.identity.aio import ChainedTokenCredential +import pytest +from unittest.mock import Mock + +from helpers_async import get_completed_future + + +@pytest.mark.asyncio +async def test_close(): + credentials = [Mock(close=Mock(wraps=get_completed_future)) for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + await chain.close() + + for credential in credentials: + assert credential.close.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + credentials = [Mock(close=Mock(wraps=get_completed_future)) for _ in range(5)] + chain = ChainedTokenCredential(*credentials) + + async with chain: + pass + + for credential in credentials: + assert credential.close.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 20beddc5917c..34c3197d64b7 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -10,12 +10,34 @@ from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import ClientSecretCredential - -from helpers import async_validating_transport, build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request +from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future import pytest +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) + + await credential.close() + + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 1 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + @pytest.mark.asyncio async def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) @@ -88,7 +110,7 @@ async def test_cache(): "token_type": "Bearer", } mock_send = Mock(return_value=mock_response(json_payload=token_payload)) - transport = Mock(send=asyncio.coroutine(mock_send)) + transport = Mock(send=wrap_in_future(mock_send)) scope = "scope" credential = ClientSecretCredential("tenant-id", "client-id", "secret", transport=transport) diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 0efab8607e31..a83bce5b136d 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -7,14 +7,14 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.core.credentials import AccessToken from azure.identity import KnownAuthorities from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential from azure.identity.aio._credentials.managed_identity import ImdsCredential, MsiCredential from azure.identity._constants import EnvironmentVariables import pytest -from helpers import async_validating_transport, mock_response, Request +from helpers import mock_response, Request +from helpers_async import async_validating_transport, wrap_in_future from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache @@ -61,7 +61,7 @@ async def send(request, **_): # managed identity credential should ignore authority with patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://some.url"}): - transport = Mock(send=asyncio.coroutine(lambda *_, **__: response)) + transport = Mock(send=wrap_in_future(lambda *_, **__: response)) if authority_kwarg: credential = DefaultAzureCredential(authority=authority_kwarg, transport=transport) else: diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index 66aaef45b7d1..cef2f7a54fe1 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -21,8 +21,8 @@ from azure.identity.aio._credentials.managed_identity import ImdsCredential from azure.identity._constants import EnvironmentVariables -from helpers import mock_response, Request, async_validating_transport - +from helpers import mock_response, Request +from helpers_async import async_validating_transport, wrap_in_future @pytest.mark.asyncio async def test_client_secret_environment_credential(): @@ -84,7 +84,7 @@ async def raise_authn_error(message="it didn't work"): credentials = [ Mock(get_token=Mock(wraps=raise_authn_error)), Mock(get_token=Mock(wraps=raise_authn_error)), - Mock(get_token=asyncio.coroutine(lambda _: expected_token)), + Mock(get_token=wrap_in_future(lambda _: expected_token)), ] token = await ChainedTokenCredential(*credentials).get_token("scope") @@ -97,7 +97,7 @@ async def raise_authn_error(message="it didn't work"): @pytest.mark.asyncio async def test_chain_returns_first_token(): expected_token = Mock() - first_credential = Mock(get_token=asyncio.coroutine(lambda _: expected_token)) + first_credential = Mock(get_token=wrap_in_future(lambda _: expected_token)) second_credential = Mock(get_token=Mock()) aggregate = ChainedTokenCredential(first_credential, second_credential) @@ -130,7 +130,7 @@ async def test_imds_credential_cache(): ) mock_send = Mock(return_value=mock_response) - credential = ImdsCredential(transport=Mock(send=asyncio.coroutine(mock_send))) + credential = ImdsCredential(transport=Mock(send=wrap_in_future(mock_send))) token = await credential.get_token(scope) assert token.token == expired assert mock_send.call_count == 2 # first request was probing for endpoint availability @@ -166,7 +166,7 @@ async def test_imds_credential_retries(): mock_response.status_code = status_code try: await ImdsCredential( - transport=Mock(send=asyncio.coroutine(mock_send), sleep=asyncio.coroutine(lambda _: None)) + transport=Mock(send=wrap_in_future(mock_send), sleep=wrap_in_future(lambda _: None)) ).get_token("scope") except ClientAuthenticationError: pass diff --git a/sdk/identity/azure-identity/tests/test_imds_credential_async.py b/sdk/identity/azure-identity/tests/test_imds_credential_async.py new file mode 100644 index 000000000000..803e4182cc11 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_imds_credential_async.py @@ -0,0 +1,30 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.identity.aio._credentials.managed_identity import ImdsCredential +import pytest + +from helpers_async import AsyncMockTransport + + +@pytest.mark.asyncio +async def test_imds_close(): + transport = AsyncMockTransport() + + credential = ImdsCredential(transport=transport) + + await credential.close() + + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_imds_context_manager(): + transport = AsyncMockTransport() + credential = ImdsCredential(transport=transport) + + async with credential: + pass + + assert transport.__aexit__.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index e3cc4e42cf85..65cf7f495b9c 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -12,7 +12,8 @@ import pytest -from helpers import async_validating_transport, mock_response, Request +from helpers import mock_response, Request +from helpers_async import async_validating_transport @pytest.mark.asyncio diff --git a/sdk/identity/azure-identity/tests/test_msi_credential_async.py b/sdk/identity/azure-identity/tests/test_msi_credential_async.py new file mode 100644 index 000000000000..2f48abf9094f --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_msi_credential_async.py @@ -0,0 +1,37 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from unittest import mock + +from azure.identity._constants import EnvironmentVariables +from azure.identity.aio._credentials.managed_identity import MsiCredential +import pytest + +from helpers_async import AsyncMockTransport + + +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + + with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://url"}): + credential = MsiCredential(transport=transport) + + await credential.close() + + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + + with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://url"}): + credential = MsiCredential(transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 1 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 92a052762daf..5aa8342ae9e0 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from unittest.mock import Mock +from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy @@ -19,10 +19,53 @@ from msal import TokenCache import pytest -from helpers import async_validating_transport, build_aad_response, build_id_token, mock_response, Request +from helpers import build_aad_response, build_id_token, mock_response, Request +from helpers_async import async_validating_transport, AsyncMockTransport from test_shared_cache_credential import get_account_event, populated_cache +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + credential = SharedTokenCacheCredential( + _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport + ) + + await credential.close() + + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + credential = SharedTokenCacheCredential( + _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport + ) + + async with credential: + assert transport.__aenter__.call_count == 1 + + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager_no_cache(): + """the credential shouldn't open/close sessions when instantiated in an environment with no cache""" + + transport = AsyncMockTransport() + with patch.dict("azure.identity._internal.shared_token_cache.os.environ", {}, clear=True): + # clearing the environment ensures the credential won't try to load a cache + credential = SharedTokenCacheCredential(transport=transport) + + async with credential: + assert transport.__aenter__.call_count == 0 + + assert transport.__aenter__.call_count == 0 + assert transport.__aexit__.call_count == 0 + + @pytest.mark.asyncio async def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 200ce3272a6d..7dcfd9a45633 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -13,18 +13,28 @@ requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the protocol again. """ +from typing import TYPE_CHECKING -from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import AsyncHTTPPolicy -from azure.core.pipeline.transport import HttpResponse -from . import ChallengeAuthPolicyBase, HttpChallenge, HttpChallengeCache +from . import ChallengeAuthPolicyBase, HttpChallengeCache + +if TYPE_CHECKING: + from typing import Any + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline import PipelineRequest + from azure.core.pipeline.transport import HttpResponse + from . import HttpChallenge class AsyncChallengeAuthPolicy(ChallengeAuthPolicyBase, AsyncHTTPPolicy): """policy for handling HTTP authentication challenges""" - async def send(self, request: PipelineRequest) -> HttpResponse: + def __init__(self, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: + self._credential = credential + super().__init__(**kwargs) + + async def send(self, request: "PipelineRequest") -> "HttpResponse": challenge = HttpChallengeCache.get_challenge_for_url(request.http_request.url) if not challenge: challenge_request = self._get_challenge_request(request) @@ -54,7 +64,7 @@ async def send(self, request: PipelineRequest) -> HttpResponse: return response - async def _handle_challenge(self, request: PipelineRequest, challenge: HttpChallenge) -> None: + async def _handle_challenge(self, request: "PipelineRequest", challenge: "HttpChallenge") -> None: """authenticate according to challenge, add Authorization header to request""" if self._need_new_token: diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index a16fca4f5b28..617965080f78 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -31,16 +31,14 @@ if TYPE_CHECKING: # pylint:disable=unused-import + from typing import Any + from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpResponse class ChallengeAuthPolicyBase(_BearerTokenCredentialPolicyBase): """Sans I/O base for challenge authentication policies""" - # pylint:disable=useless-super-delegation - def __init__(self, credential, **kwargs): - super(ChallengeAuthPolicyBase, self).__init__(credential, **kwargs) - @staticmethod def _update_challenge(request, challenger): # type: (HttpRequest, HttpResponse) -> HttpChallenge @@ -74,6 +72,11 @@ def _get_challenge_request(request): class ChallengeAuthPolicy(ChallengeAuthPolicyBase, HTTPPolicy): """policy for handling HTTP authentication challenges""" + def __init__(self, credential, **kwargs): + # type: (TokenCredential, **Any) -> None + self._credential = credential + super(ChallengeAuthPolicy, self).__init__(**kwargs) + def send(self, request): # type: (PipelineRequest) -> HttpResponse diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 200ce3272a6d..7dcfd9a45633 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -13,18 +13,28 @@ requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the protocol again. """ +from typing import TYPE_CHECKING -from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import AsyncHTTPPolicy -from azure.core.pipeline.transport import HttpResponse -from . import ChallengeAuthPolicyBase, HttpChallenge, HttpChallengeCache +from . import ChallengeAuthPolicyBase, HttpChallengeCache + +if TYPE_CHECKING: + from typing import Any + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline import PipelineRequest + from azure.core.pipeline.transport import HttpResponse + from . import HttpChallenge class AsyncChallengeAuthPolicy(ChallengeAuthPolicyBase, AsyncHTTPPolicy): """policy for handling HTTP authentication challenges""" - async def send(self, request: PipelineRequest) -> HttpResponse: + def __init__(self, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: + self._credential = credential + super().__init__(**kwargs) + + async def send(self, request: "PipelineRequest") -> "HttpResponse": challenge = HttpChallengeCache.get_challenge_for_url(request.http_request.url) if not challenge: challenge_request = self._get_challenge_request(request) @@ -54,7 +64,7 @@ async def send(self, request: PipelineRequest) -> HttpResponse: return response - async def _handle_challenge(self, request: PipelineRequest, challenge: HttpChallenge) -> None: + async def _handle_challenge(self, request: "PipelineRequest", challenge: "HttpChallenge") -> None: """authenticate according to challenge, add Authorization header to request""" if self._need_new_token: diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index a16fca4f5b28..617965080f78 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -31,16 +31,14 @@ if TYPE_CHECKING: # pylint:disable=unused-import + from typing import Any + from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpResponse class ChallengeAuthPolicyBase(_BearerTokenCredentialPolicyBase): """Sans I/O base for challenge authentication policies""" - # pylint:disable=useless-super-delegation - def __init__(self, credential, **kwargs): - super(ChallengeAuthPolicyBase, self).__init__(credential, **kwargs) - @staticmethod def _update_challenge(request, challenger): # type: (HttpRequest, HttpResponse) -> HttpChallenge @@ -74,6 +72,11 @@ def _get_challenge_request(request): class ChallengeAuthPolicy(ChallengeAuthPolicyBase, HTTPPolicy): """policy for handling HTTP authentication challenges""" + def __init__(self, credential, **kwargs): + # type: (TokenCredential, **Any) -> None + self._credential = credential + super(ChallengeAuthPolicy, self).__init__(**kwargs) + def send(self, request): # type: (PipelineRequest) -> HttpResponse diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 200ce3272a6d..7dcfd9a45633 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -13,18 +13,28 @@ requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the protocol again. """ +from typing import TYPE_CHECKING -from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import AsyncHTTPPolicy -from azure.core.pipeline.transport import HttpResponse -from . import ChallengeAuthPolicyBase, HttpChallenge, HttpChallengeCache +from . import ChallengeAuthPolicyBase, HttpChallengeCache + +if TYPE_CHECKING: + from typing import Any + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline import PipelineRequest + from azure.core.pipeline.transport import HttpResponse + from . import HttpChallenge class AsyncChallengeAuthPolicy(ChallengeAuthPolicyBase, AsyncHTTPPolicy): """policy for handling HTTP authentication challenges""" - async def send(self, request: PipelineRequest) -> HttpResponse: + def __init__(self, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: + self._credential = credential + super().__init__(**kwargs) + + async def send(self, request: "PipelineRequest") -> "HttpResponse": challenge = HttpChallengeCache.get_challenge_for_url(request.http_request.url) if not challenge: challenge_request = self._get_challenge_request(request) @@ -54,7 +64,7 @@ async def send(self, request: PipelineRequest) -> HttpResponse: return response - async def _handle_challenge(self, request: PipelineRequest, challenge: HttpChallenge) -> None: + async def _handle_challenge(self, request: "PipelineRequest", challenge: "HttpChallenge") -> None: """authenticate according to challenge, add Authorization header to request""" if self._need_new_token: diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index a16fca4f5b28..617965080f78 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -31,16 +31,14 @@ if TYPE_CHECKING: # pylint:disable=unused-import + from typing import Any + from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpResponse class ChallengeAuthPolicyBase(_BearerTokenCredentialPolicyBase): """Sans I/O base for challenge authentication policies""" - # pylint:disable=useless-super-delegation - def __init__(self, credential, **kwargs): - super(ChallengeAuthPolicyBase, self).__init__(credential, **kwargs) - @staticmethod def _update_challenge(request, challenger): # type: (HttpRequest, HttpResponse) -> HttpChallenge @@ -74,6 +72,11 @@ def _get_challenge_request(request): class ChallengeAuthPolicy(ChallengeAuthPolicyBase, HTTPPolicy): """policy for handling HTTP authentication challenges""" + def __init__(self, credential, **kwargs): + # type: (TokenCredential, **Any) -> None + self._credential = credential + super(ChallengeAuthPolicy, self).__init__(**kwargs) + def send(self, request): # type: (PipelineRequest) -> HttpResponse