diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 5740ca3f3989..876882b471f2 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -4,9 +4,13 @@ ### Features Added -- `AccessToken` now has an optional `refresh_on` attribute that can be used to specify when the token should be refreshed. #36183 - - `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check the `refresh_on` attribute when determining if a token request should be made. -- Added `azure.core.AzureClouds` enum to represent the different Azure clouds. +- Added azure.core.AzureClouds enum to represent the different Azure clouds. +- Added two new credential protocol classes, `SupportsTokenInfo` and `AsyncSupportsTokenInfo`, to offer more extensibility in supporting various token acquisition scenarios. #36565 + - Each new protocol class defines a `get_token_info` method that returns an `AccessTokenInfo` object. +- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. #36565 +- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. #36565 +- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now first check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. Otherwise, the `get_token` method is used. #36565 + - These policies now also check the `refresh_on` attribute when determining if a new token request should be made. ### Breaking Changes diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index a2885137ba54..dbe1a7cdbde0 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple, Optional, TypedDict, Union, ContextManager from typing_extensions import Protocol, runtime_checkable @@ -11,13 +11,56 @@ class AccessToken(NamedTuple): """Represents an OAuth access token.""" token: str + """The token string.""" expires_on: int - refresh_on: Optional[int] = None + """The token's expiration time in Unix time.""" -AccessToken.token.__doc__ = """The token string.""" -AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time.""" -AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time.""" +class AccessTokenInfo: + """Information about an OAuth access token. + + This class is an alternative to `AccessToken` which provides additional information about the token. + + :param str token: The token string. + :param int expires_on: The token's expiration time in Unix time. + :keyword str token_type: The type of access token. Defaults to 'Bearer'. + :keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively + refreshed. Optional. + """ + + token: str + """The token string.""" + expires_on: int + """The token's expiration time in Unix time.""" + token_type: str + """The type of access token.""" + refresh_on: Optional[int] + """Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional.""" + + def __init__( + self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None + ) -> None: + self.token = token + self.expires_on = expires_on + self.token_type = token_type + self.refresh_on = refresh_on + + def __repr__(self) -> str: + return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format( + self.token, self.expires_on, self.token_type, self.refresh_on + ) + + +class TokenRequestOptions(TypedDict, total=False): + """Options to use for access token requests. All parameters are optional.""" + + claims: str + """Additional claims required in the token, such as those returned in a resource provider's claims + challenge following an authorization failure.""" + tenant_id: str + """The tenant ID to include in the token request.""" + enable_cae: bool + """Indicates whether to enable Continuous Access Evaluation (CAE) for the requested token.""" @runtime_checkable @@ -30,7 +73,7 @@ def get_token( claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -48,6 +91,32 @@ def get_token( ... +@runtime_checkable +class SupportsTokenInfo(Protocol, ContextManager["SupportsTokenInfo"]): + """Protocol for classes able to provide OAuth access tokens with additional properties.""" + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. + + :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + """ + ... + + def close(self) -> None: + pass + + +TokenProvider = Union[TokenCredential, SupportsTokenInfo] + + class AzureNamedKey(NamedTuple): """Represents a name and key pair.""" @@ -59,8 +128,12 @@ class AzureNamedKey(NamedTuple): "AzureKeyCredential", "AzureSasCredential", "AccessToken", + "AccessTokenInfo", + "SupportsTokenInfo", "AzureNamedKeyCredential", "TokenCredential", + "TokenRequestOptions", + "TokenProvider", ] diff --git a/sdk/core/azure-core/azure/core/credentials_async.py b/sdk/core/azure-core/azure/core/credentials_async.py index 94e470d88c47..ad0e80b73e96 100644 --- a/sdk/core/azure-core/azure/core/credentials_async.py +++ b/sdk/core/azure-core/azure/core/credentials_async.py @@ -4,9 +4,13 @@ # ------------------------------------ from __future__ import annotations from types import TracebackType -from typing import Any, Optional, AsyncContextManager, Type +from typing import Any, Optional, AsyncContextManager, Type, Union from typing_extensions import Protocol, runtime_checkable -from .credentials import AccessToken as _AccessToken +from .credentials import ( + AccessToken as _AccessToken, + AccessTokenInfo as _AccessTokenInfo, + TokenRequestOptions as _TokenRequestOptions, +) @runtime_checkable @@ -46,3 +50,37 @@ async def __aexit__( traceback: Optional[TracebackType] = None, ) -> None: pass + + +@runtime_checkable +class AsyncSupportsTokenInfo(Protocol, AsyncContextManager["AsyncSupportsTokenInfo"]): + """Protocol for classes able to provide OAuth access tokens with additional properties.""" + + async def get_token_info(self, *scopes: str, options: Optional[_TokenRequestOptions] = None) -> _AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. + + :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time. + """ + ... + + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass + + +AsyncTokenProvider = Union[AsyncTokenCredential, AsyncSupportsTokenInfo] diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index cf0a2c662efe..c28d03bcf771 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -4,7 +4,8 @@ # license information. # ------------------------------------------------------------------------- import time -from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any +from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast +from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest from azure.core.rest import HttpResponse, HttpRequest @@ -15,7 +16,7 @@ # pylint:disable=unused-import from azure.core.credentials import ( AccessToken, - TokenCredential, + AccessTokenInfo, AzureKeyCredential, AzureSasCredential, ) @@ -29,17 +30,17 @@ class _BearerTokenCredentialPolicyBase: """Base class for a Bearer Token Credential Policy. :param credential: The credential. - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider :param str scopes: Lets you specify the type of access needed. :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested tokens. Defaults to False. """ - def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential - self._token: Optional["AccessToken"] = None + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._enable_cae: bool = kwargs.get("enable_cae", False) @staticmethod @@ -70,11 +71,29 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @property def _need_new_token(self) -> bool: now = time.time() - return ( - not self._token - or (self._token.refresh_on is not None and self._token.refresh_on <= now) - or self._token.expires_on - now < 300 - ) + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + if self._enable_cae: + kwargs.setdefault("enable_cae", self._enable_cae) + + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + else: + self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): @@ -98,11 +117,9 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: self._enforce_https(request) if self._token is None or self._need_new_token: - if self._enable_cae: - self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae) - else: - self._token = self._credential.get_token(*self._scopes) - self._update_headers(request.http_request.headers, self._token.token) + self._request_token(*self._scopes) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + self._update_headers(request.http_request.headers, bearer_token) def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -113,10 +130,9 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ - if self._enable_cae: - kwargs.setdefault("enable_cae", self._enable_cae) - self._token = self._credential.get_token(*scopes, **kwargs) - self._update_headers(request.http_request.headers, self._token.token) + self._request_token(*scopes, **kwargs) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + self._update_headers(request.http_request.headers, bearer_token) def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: """Authorize request with a bearer token and send it to the next policy diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index cd5ed6773a17..e4b0b328eff7 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -4,9 +4,10 @@ # license information. # ------------------------------------------------------------------------- import time -from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar +from typing import Any, Awaitable, Optional, cast, TypeVar, Union -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncHTTPPolicy from azure.core.pipeline.policies._authentication import ( @@ -18,9 +19,6 @@ from .._tools_async import await_result -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential - AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse) HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) @@ -29,18 +27,18 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT """Adds a bearer token Authorization header to requests. :param credential: The credential. - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider :param str scopes: Lets you specify the type of access needed. :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested tokens. Defaults to False. """ - def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: super().__init__() self._credential = credential self._scopes = scopes self._lock_instance = None - self._token: Optional["AccessToken"] = None + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._enable_cae: bool = kwargs.get("enable_cae", False) @property @@ -62,13 +60,9 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: async with self._lock: # double check because another coroutine may have acquired a token while we waited to acquire the lock if self._token is None or self._need_new_token(): - if self._enable_cae: - self._token = await await_result( - self._credential.get_token, *self._scopes, enable_cae=self._enable_cae - ) - else: - self._token = await await_result(self._credential.get_token, *self._scopes) - request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token + await self._request_token(*self._scopes) + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -79,11 +73,11 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ - if self._enable_cae: - kwargs.setdefault("enable_cae", self._enable_cae) + async with self._lock: - self._token = await await_result(self._credential.get_token, *scopes, **kwargs) - request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token + await self._request_token(*scopes, **kwargs) + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token async def send( self, request: PipelineRequest[HTTPRequestType] @@ -165,8 +159,30 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: def _need_new_token(self) -> bool: now = time.time() - return ( - not self._token - or (self._token.refresh_on is not None and self._token.refresh_on <= now) - or self._token.expires_on - now < 300 - ) + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + if self._enable_cae: + kwargs.setdefault("enable_cae", self._enable_cae) + + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + + self._token = await await_result( + cast(AsyncSupportsTokenInfo, self._credential).get_token_info, + *scopes, + options=options, + ) + else: + self._token = await await_result(cast(AsyncTokenCredential, self._credential).get_token, *scopes, **kwargs) diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index dc0bb926bce4..416875256888 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -6,13 +6,13 @@ import asyncio import sys import time -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, AsyncMock, create_autospec from requests import Response -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import AccessToken, AccessTokenInfo +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from azure.core.exceptions import ServiceRequestError -from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline import AsyncPipeline, PipelineRequest, PipelineContext from azure.core.pipeline.policies import ( AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy, @@ -56,6 +56,84 @@ async def get_token(*_, **__): assert get_token_calls == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_authorize_request(http_request): + """The authorize_request method should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + fake_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) + policy = AsyncBearerTokenCredentialPolicy(fake_credential, "scope") + http_req = http_request("GET", "https://spam.eggs") + request = PipelineRequest(http_req, PipelineContext(None)) + + await policy.authorize_request(request, "scope", claims="foo") + assert policy._token is expected_token + assert http_req.headers["Authorization"] == f"Bearer {expected_token.token}" + assert fake_credential.get_token.call_count == 1 + assert fake_credential.get_token.call_args[0] == ("scope",) + assert fake_credential.get_token.call_args[1] == {"claims": "foo"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_adds_header_access_token_info(http_request): + """The bearer token policy should also add an auth header when an AccessTokenInfo is returned.""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + access_token = AccessToken("other_token", 2524608000) + expected_token = AccessTokenInfo("expected_token", 2524608000, refresh_on=2524608000) + + async def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + get_token_calls = 0 + get_token_info_calls = 0 + + class MockCredential(AsyncTokenCredential): + async def get_token(self, *_, **__): + nonlocal get_token_calls + get_token_calls += 1 + return access_token + + async def get_token_info(*_, **__): + nonlocal get_token_info_calls + get_token_info_calls += 1 + return expected_token + + fake_credential: AsyncTokenCredential = MockCredential() + policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + pipeline = AsyncPipeline(transport=AsyncMock(), policies=policies) + + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) + assert get_token_info_calls == 1 + + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) + # Didn't need a new token + assert get_token_info_calls == 1 + # get_token should not have been called + assert get_token_calls == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_authorize_request_access_token_info(http_request): + """The authorize_request method should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessTokenInfo("expected_token", 2524608000) + fake_credential = Mock(get_token=Mock(), get_token_info=Mock(return_value=expected_token)) + policy = AsyncBearerTokenCredentialPolicy(fake_credential, "scope") + http_req = http_request("GET", "https://spam.eggs") + request = PipelineRequest(http_req, PipelineContext(None)) + + await policy.authorize_request(request, "scope", claims="foo") + assert policy._token is expected_token + assert http_req.headers["Authorization"] == f"Bearer {expected_token.token}" + assert fake_credential.get_token_info.call_args[0] == ("scope",) + assert fake_credential.get_token_info.call_args[1] == {"options": {"claims": "foo"}} + + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_bearer_policy_send(http_request): @@ -97,7 +175,7 @@ async def verify_request(request): @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_bearer_policy_token_caching(http_request): - good_for_one_hour = AccessToken("token", time.time() + 3600) + good_for_one_hour = AccessToken("token", int(time.time() + 3600)) expected_token = good_for_one_hour get_token_calls = 0 @@ -119,7 +197,7 @@ async def get_token(*_, **__): await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache - expired_token = AccessToken("token", time.time()) + expired_token = AccessToken("token", int(time.time())) get_token_calls = 0 expected_token = expired_token policies = [ @@ -135,6 +213,48 @@ async def get_token(*_, **__): assert get_token_calls == 2 # token expired -> policy should call get_token +@pytest.mark.asyncio +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_access_token_info_caching(http_request): + """The policy should cache AccessTokenInfo instances and refresh them when necessary.""" + + good_for_one_hour = AccessTokenInfo("token", int(time.time() + 3600)) + + credential = create_autospec(AsyncSupportsTokenInfo, instance=True, spec_set=True) + credential.get_token_info = AsyncMock(return_value=good_for_one_hour) + pipeline = AsyncPipeline(transport=AsyncMock(), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")]) + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert ( + credential.get_token_info.call_count == 1 + ) # policy has no token at first request -> it should call get_token_info + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessTokenInfo("token", int(time.time())) + credential.get_token_info.reset_mock() + credential.get_token_info.return_value = expired_token + pipeline = AsyncPipeline(transport=AsyncMock(), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")]) + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 2 # token is expired -> policy should call get_token_info again + + refreshable_token = AccessTokenInfo("token", int(time.time() + 3600), refresh_on=int(time.time() - 1)) + credential.get_token_info.reset_mock() + credential.get_token_info.return_value = refreshable_token + pipeline = AsyncPipeline(transport=AsyncMock(), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")]) + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 + + await pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 2 # token refresh-on time has passed, call again + + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_bearer_policy_optionally_enforces_https(http_request): diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index 2f323db3ab19..eeeafaab0041 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -3,13 +3,20 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- +from collections import namedtuple import time from itertools import product from requests import Response import azure.core -from azure.core.credentials import AccessToken, AzureKeyCredential, AzureSasCredential, AzureNamedKeyCredential +from azure.core.credentials import ( + AccessToken, + AzureKeyCredential, + AzureSasCredential, + AzureNamedKeyCredential, + AccessTokenInfo, +) from azure.core.exceptions import ServiceRequestError -from azure.core.pipeline import Pipeline +from azure.core.pipeline import Pipeline, PipelineRequest, PipelineContext from azure.core.pipeline.transport import HttpTransport, HttpRequest from azure.core.pipeline.policies import ( BearerTokenCredentialPolicy, @@ -50,6 +57,70 @@ def verify_authorization_header(request): assert fake_credential.get_token.call_count == 1 +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_authorize_request(http_request): + """The authorize_request method should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + fake_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) + policy = BearerTokenCredentialPolicy(fake_credential, "scope") + http_req = http_request("GET", "https://spam.eggs") + request = PipelineRequest(http_req, PipelineContext(None)) + + policy.authorize_request(request, "scope", claims="foo") + assert policy._token is expected_token + assert http_req.headers["Authorization"] == f"Bearer {expected_token.token}" + assert fake_credential.get_token.call_count == 1 + assert fake_credential.get_token.call_args[0] == ("scope",) + assert fake_credential.get_token.call_args[1] == {"claims": "foo"} + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_adds_header_access_token_info(http_request): + """The bearer token policy should also add an auth header when an AccessTokenInfo is returned.""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + access_token = AccessToken("other_token", 2524608000) + expected_token = AccessTokenInfo("expected_token", 2524608000, refresh_on=2524608000) + + def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + fake_credential = Mock(get_token=Mock(return_value=access_token), get_token_info=Mock(return_value=expected_token)) + policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + + pipeline = Pipeline(transport=Mock(), policies=policies) + pipeline.run(http_request("GET", "https://spam.eggs")) + + assert fake_credential.get_token_info.call_count == 1 + + pipeline.run(http_request("GET", "https://spam.eggs")) + + # Didn't need a new token + assert fake_credential.get_token_info.call_count == 1 + + # get_token should not have been called + assert fake_credential.get_token.call_count == 0 + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_authorize_request_access_token_info(http_request): + """The authorize_request method should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessTokenInfo("expected_token", 2524608000) + fake_credential = Mock(get_token=Mock(), get_token_info=Mock(return_value=expected_token)) + policy = BearerTokenCredentialPolicy(fake_credential, "scope") + http_req = http_request("GET", "https://spam.eggs") + request = PipelineRequest(http_req, PipelineContext(None)) + + policy.authorize_request(request, "scope", claims="foo") + assert policy._token is expected_token + assert http_req.headers["Authorization"] == f"Bearer {expected_token.token}" + assert fake_credential.get_token_info.call_args[0] == ("scope",) + assert fake_credential.get_token_info.call_args[1] == {"options": {"claims": "foo"}} + + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_bearer_policy_send(http_request): """The bearer token policy should invoke the next policy's send method and return the result""" @@ -72,7 +143,7 @@ def get_token(*_, **__): @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_bearer_policy_token_caching(http_request): - good_for_one_hour = AccessToken("token", time.time() + 3600) + good_for_one_hour = AccessToken("token", int(time.time() + 3600)) credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=good_for_one_hour)) pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) @@ -82,7 +153,7 @@ def test_bearer_policy_token_caching(http_request): pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache - expired_token = AccessToken("token", time.time()) + expired_token = AccessToken("token", int(time.time())) credential.get_token.reset_mock() credential.get_token.return_value = expired_token pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) @@ -91,7 +162,46 @@ def test_bearer_policy_token_caching(http_request): assert credential.get_token.call_count == 1 pipeline.run(http_request("GET", "https://spam.eggs")) - assert credential.get_token.call_count == 2 # token expired -> policy should call get_token + assert credential.get_token.call_count == 2 + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_access_token_info_caching(http_request): + """The policy should cache AccessTokenInfo instances and refresh them when necessary.""" + + good_for_one_hour = AccessTokenInfo("token", int(time.time() + 3600)) + credential = Mock(get_token=Mock(return_value=Mock()), get_token_info=Mock(return_value=good_for_one_hour)) + pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert ( + credential.get_token_info.call_count == 1 + ) # policy has no token at first request -> it should call get_token_info + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessTokenInfo("token", int(time.time())) + credential.get_token_info.reset_mock() + credential.get_token_info.return_value = expired_token + pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 2 # token is expired -> policy should call get_token_info again + + refreshable_token = AccessTokenInfo("token", int(time.time() + 3600), refresh_on=int(time.time() - 1)) + credential.get_token_info.reset_mock() + credential.get_token_info.return_value = refreshable_token + pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 1 + + pipeline.run(http_request("GET", "https://spam.eggs")) + assert credential.get_token_info.call_count == 2 # token refresh-on time has passed, call again @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -324,11 +434,32 @@ def test_need_new_token(): assert not policy._need_new_token # Token has both expires_on and refresh_on set well into the future. - policy._token = AccessToken("", now + 1200, now + 1200) + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now + 1200) assert not policy._need_new_token # Token is not close to expiring, but refresh_on is in the past. - policy._token = AccessToken("", now + 1200, now - 1) + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1) + assert policy._need_new_token + + policy._token = None + assert policy._need_new_token + + +def test_need_new_token_with_external_defined_token_class(): + """Test the case where some custom credential get_token call returns a custom token object.""" + FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"]) + + expected_scope = "scope" + now = int(time.time()) + + policy = BearerTokenCredentialPolicy(Mock(), expected_scope) + + # Token is expired. + policy._token = FooAccessToken("", now - 1200) + assert policy._need_new_token + + # Token is about to expire within 300 seconds. + policy._token = FooAccessToken("", now + 299) assert policy._need_new_token @@ -638,3 +769,25 @@ def verify_authorization_header(request): pipeline = Pipeline(transport=transport, policies=[credential_policy]) pipeline.run(http_request("GET", "https://test_key_credential")) + + +def test_access_token_unpack(): + """Test unpacking of AccessToken.""" + token = AccessToken("token", 42) + assert token.token == "token" + assert token.expires_on == 42 + + token, expires_on = AccessToken("token", 42) + assert token == "token" + assert expires_on == 42 + + with pytest.raises(ValueError): + token, expires_on, _ = AccessToken("token", 42) + + +def test_access_token_subscriptable(): + """Test AccessToken property access using index values.""" + token = AccessToken("token", 42) + assert len(token) == 2 + assert token[0] == "token" + assert token[1] == 42 diff --git a/sdk/core/azure-mgmt-core/tests/test_authentication.py b/sdk/core/azure-mgmt-core/tests/test_authentication.py index c621476cfe29..cdc7b1d4c8f8 100644 --- a/sdk/core/azure-mgmt-core/tests/test_authentication.py +++ b/sdk/core/azure-mgmt-core/tests/test_authentication.py @@ -35,6 +35,8 @@ ) from azure.core.pipeline.transport import HttpRequest +from devtools_testutils.fake_credentials import FakeTokenCredential + import pytest from unittest.mock import Mock @@ -171,14 +173,14 @@ def send(request): return Mock(status_code=401, headers={"WWW-Authenticate": expected_header}) transport = Mock(send=Mock(wraps=send)) - credential = Mock() + credential = FakeTokenCredential() policies = [ARMChallengeAuthenticationPolicy(credential, "scope")] pipeline = Pipeline(transport=transport, policies=policies) response = pipeline.run(HttpRequest("GET", "https://localhost")) assert transport.send.call_count == 1 - assert credential.get_token.call_count == 1 + assert credential.get_token_count == 1 # the policy should have returned the error response because it was unable to handle the challenge assert response.http_response.status_code == 401 diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 6fc8f7892ac8..1a872f36b6a8 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -38,7 +38,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: super().__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 41ed3fe794b8..f16297aa5026 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -67,7 +67,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 6fc8f7892ac8..1a872f36b6a8 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -38,7 +38,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: super().__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 41ed3fe794b8..f16297aa5026 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -67,7 +67,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 6fc8f7892ac8..1a872f36b6a8 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -38,7 +38,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: super().__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 41ed3fe794b8..f16297aa5026 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -67,7 +67,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 6fc8f7892ac8..1a872f36b6a8 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -38,7 +38,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: super().__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 41ed3fe794b8..f16297aa5026 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -67,7 +67,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential = credential + self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None