From 036052e6bbfba1b8a2b1968bf7ff8bbf5990eb79 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 5 Nov 2024 20:22:51 +0000 Subject: [PATCH 1/2] [Corehttp] Update credential classes Signed-off-by: Paul Van Eck --- sdk/core/corehttp/corehttp/credentials.py | 76 +++++++++++----- .../runtime/policies/_authentication.py | 21 ++++- .../runtime/policies/_authentication_async.py | 30 ++++-- .../async_tests/test_authentication_async.py | 72 ++++++++++----- .../corehttp/tests/test_authentication.py | 91 +++++++++++++------ 5 files changed, 203 insertions(+), 87 deletions(-) diff --git a/sdk/core/corehttp/corehttp/credentials.py b/sdk/core/corehttp/corehttp/credentials.py index b5514ea1046b..6b31a4557ac9 100644 --- a/sdk/core/corehttp/corehttp/credentials.py +++ b/sdk/core/corehttp/corehttp/credentials.py @@ -5,39 +5,71 @@ # ------------------------------------------------------------------------- from __future__ import annotations from types import TracebackType -from typing import Any, NamedTuple, Optional, AsyncContextManager, Type +from typing import NamedTuple, Optional, AsyncContextManager, Type, TypedDict, ContextManager from typing_extensions import Protocol, runtime_checkable -class AccessToken(NamedTuple): - """Represents an OAuth access token.""" +class AccessTokenInfo: + """Information about an OAuth access token. + + :param str token: The token string. + :param int expires_on: The token's expiration time in Unix time. + :keyword str token_type: The type of access token. Defaults to 'Bearer'. + :keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively + refreshed. Optional. + """ token: str + """The token string.""" expires_on: int + """The token's expiration time in Unix time.""" + token_type: str + """The type of access token.""" + refresh_on: Optional[int] + """Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional.""" + + def __init__( + self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None + ) -> None: + self.token = token + self.expires_on = expires_on + self.token_type = token_type + self.refresh_on = refresh_on + def __repr__(self) -> str: + return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format( + self.token, self.expires_on, self.token_type, self.refresh_on + ) -AccessToken.token.__doc__ = """The token string.""" -AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time.""" +class TokenRequestOptions(TypedDict, total=False): + """Options to use for access token requests. All parameters are optional.""" -@runtime_checkable -class TokenCredential(Protocol): - """Protocol for classes able to provide OAuth tokens.""" + claims: str + """Additional claims required in the token, such as those returned in a resource provider's claims + challenge following an authorization failure.""" + tenant_id: str + """The tenant ID to include in the token request.""" - def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken: - """Request an access token for `scopes`. - :param str scopes: The type of access needed. +class TokenCredential(Protocol, ContextManager["TokenCredential"]): + """Protocol for classes able to provide OAuth access tokens.""" - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. """ ... + def close(self) -> None: + pass + class ServiceNamedKey(NamedTuple): """Represents a name and key pair.""" @@ -47,10 +79,11 @@ class ServiceNamedKey(NamedTuple): __all__ = [ - "AccessToken", + "AccessTokenInfo", "ServiceKeyCredential", "ServiceNamedKeyCredential", "TokenCredential", + "TokenRequestOptions", "AsyncTokenCredential", ] @@ -134,16 +167,15 @@ def update(self, name: str, key: str) -> None: class AsyncTokenCredential(Protocol, AsyncContextManager["AsyncTokenCredential"]): """Protocol for classes able to provide OAuth tokens.""" - async def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: """Request an access token for `scopes`. :param str scopes: The type of access needed. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: TokenRequestOptions - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. - - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time. """ ... diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py b/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py index d19790472202..38a7360249ab 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py @@ -7,6 +7,7 @@ import time from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any +from ...credentials import TokenRequestOptions from ...rest import HttpResponse, HttpRequest from . import HTTPPolicy, SansIOHTTPPolicy from ...exceptions import ServiceRequestError @@ -14,7 +15,7 @@ if TYPE_CHECKING: from ...credentials import ( - AccessToken, + AccessTokenInfo, TokenCredential, ServiceKeyCredential, ) @@ -39,7 +40,7 @@ def __init__( super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential - self._token: Optional["AccessToken"] = None + self._token: Optional["AccessTokenInfo"] = None @staticmethod def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None: @@ -68,7 +69,12 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + return ( + not self._token + or (self._token.refresh_on is not None and self._token.refresh_on <= now) + or (self._token.expires_on - now < 300) + ) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): @@ -90,7 +96,7 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: self._enforce_https(request) if self._token is None or self._need_new_token: - self._token = self._credential.get_token(*self._scopes) + self._token = self._credential.get_token_info(*self._scopes) self._update_headers(request.http_request.headers, self._token.token) def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: @@ -102,7 +108,12 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: :param ~corehttp.runtime.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ - self._token = self._credential.get_token(*scopes, **kwargs) + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + self._token = self._credential.get_token_info(*scopes, options=options) self._update_headers(request.http_request.headers, self._token.token) def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py b/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py index b10bc3739eb2..8da76d8f79d1 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py @@ -7,7 +7,7 @@ import time from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar -from ...credentials import AccessToken +from ...credentials import AccessTokenInfo, TokenRequestOptions from ..pipeline import PipelineRequest, PipelineResponse from ..pipeline._tools_async import await_result from ._base_async import AsyncHTTPPolicy @@ -38,7 +38,7 @@ def __init__( self._credential = credential self._lock_instance = None self._scopes = scopes - self._token: Optional["AccessToken"] = None + self._token: Optional[AccessTokenInfo] = None @property def _lock(self): @@ -55,12 +55,12 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: """ _BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access - if self._token is None or self._need_new_token(): + if self._token is None or self._need_new_token: async with self._lock: # double check because another coroutine may have acquired a token while we waited to acquire the lock - if self._token is None or self._need_new_token(): - self._token = await await_result(self._credential.get_token, *self._scopes) - request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token + if self._token is None or self._need_new_token: + self._token = await await_result(self._credential.get_token_info, *self._scopes) + request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -71,9 +71,15 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc :param ~corehttp.runtime.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ + options: TokenRequestOptions = {} + # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions. + for key in list(kwargs.keys()): + if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member + options[key] = kwargs.pop(key) # type: ignore[literal-required] + async with self._lock: - self._token = await await_result(self._credential.get_token, *scopes, **kwargs) - request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token + self._token = await await_result(self._credential.get_token_info, *scopes, options=options) + request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token async def send( self, request: PipelineRequest[HTTPRequestType] @@ -149,5 +155,11 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: # pylint: disable=unused-argument return + @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + return ( + not self._token + or (self._token.refresh_on is not None and self._token.refresh_on <= now) + or (self._token.expires_on - now < 300) + ) diff --git a/sdk/core/corehttp/tests/async_tests/test_authentication_async.py b/sdk/core/corehttp/tests/async_tests/test_authentication_async.py index a8d0b77d6ff0..e41305903c64 100644 --- a/sdk/core/corehttp/tests/async_tests/test_authentication_async.py +++ b/sdk/core/corehttp/tests/async_tests/test_authentication_async.py @@ -7,7 +7,7 @@ import time from unittest.mock import Mock -from corehttp.credentials import AccessToken +from corehttp.credentials import AccessTokenInfo from corehttp.credentials import AsyncTokenCredential from corehttp.exceptions import ServiceRequestError from corehttp.runtime.pipeline import AsyncPipeline @@ -25,7 +25,7 @@ async def test_bearer_policy_adds_header(): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) - expected_token = AccessToken("expected_token", 2524608000) + expected_token = AccessTokenInfo("expected_token", 2524608000) async def verify_authorization_header(request): assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) @@ -33,12 +33,12 @@ async def verify_authorization_header(request): get_token_calls = 0 - async def get_token(*_, **__): + async def get_token_info(*_, **__): nonlocal get_token_calls get_token_calls += 1 return expected_token - fake_credential = Mock(spec_set=["get_token"], get_token=get_token) + fake_credential = Mock(spec_set=["get_token_info"], get_token_info=get_token_info) policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = AsyncPipeline(transport=Mock(), policies=policies) @@ -59,8 +59,8 @@ async def verify_request(request): assert request.http_request is expected_request return expected_response - get_token = get_completed_future(AccessToken("***", 42)) - fake_credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token) + get_token = get_completed_future(AccessTokenInfo("***", 42)) + fake_credential = Mock(spec_set=["get_token_info"], get_token_info=lambda *_, **__: get_token) policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)] response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request) @@ -76,8 +76,8 @@ async def verify_request(request): assert request.http_request is expected_request return expected_response - get_token = get_completed_future(AccessToken("***", 42)) - fake_credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token) + get_token = get_completed_future(AccessTokenInfo("***", 42)) + fake_credential = Mock(spec_set=["get_token_info"], get_token_info=lambda *_, **__: get_token) policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)] response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request) @@ -85,11 +85,11 @@ async def verify_request(request): async def test_bearer_policy_token_caching(): - good_for_one_hour = AccessToken("token", int(time.time()) + 3600) + good_for_one_hour = AccessTokenInfo("token", int(time.time()) + 3600) expected_token = good_for_one_hour get_token_calls = 0 - async def get_token(*_, **__): + async def get_token_info(*_, **__): nonlocal get_token_calls get_token_calls += 1 return expected_token @@ -97,7 +97,7 @@ async def get_token(*_, **__): async def send_mock(_): return Mock(http_response=Mock(status_code=200)) - credential = Mock(spec_set=["get_token"], get_token=get_token) + credential = Mock(spec_set=["get_token_info"], get_token_info=get_token_info) policies = [ AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=send_mock), @@ -110,7 +110,7 @@ async def send_mock(_): await pipeline.run(HttpRequest("GET", "https://spam.eggs")) assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache - expired_token = AccessToken("token", int(time.time())) + expired_token = AccessTokenInfo("token", int(time.time())) get_token_calls = 0 expected_token = expired_token policies = [ @@ -133,7 +133,9 @@ async def assert_option_popped(request, **kwargs): assert "enforce_https" not in kwargs, "AsyncBearerTokenCredentialPolicy didn't pop the 'enforce_https' option" return Mock() - credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42))) + credential = Mock( + spec_set=["get_token_info"], get_token_info=lambda *_, **__: get_completed_future(AccessTokenInfo("***", 42)) + ) pipeline = AsyncPipeline( transport=Mock(send=assert_option_popped), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")] ) @@ -161,8 +163,8 @@ def on_request(self, request): assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" return Mock() - get_token = get_completed_future(AccessToken("***", 42)) - credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token) + get_token = get_completed_future(AccessTokenInfo("***", 42)) + credential = Mock(spec_set=["get_token_info"], get_token_info=lambda *_, **__: get_token) policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) @@ -177,8 +179,8 @@ def on_request(self, request): assert not any(request.context), "the policy shouldn't add to the request's context" return Mock() - get_token = get_completed_future(AccessToken("***", 42)) - credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token) + get_token = get_completed_future(AccessTokenInfo("***", 42)) + credential = Mock(spec_set=["get_token_info"], get_token_info=lambda *_, **__: get_token) policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) @@ -202,7 +204,7 @@ async def send(self, request): credential = Mock( spec_set=["get_token"], - get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))), + get_token=Mock(return_value=get_completed_future(AccessTokenInfo("***", int(time.time()) + 3600))), ) policy = TestPolicy(credential, "scope") transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200)))) @@ -273,11 +275,12 @@ def get_completed_future(result=None): @pytest.mark.asyncio async def test_async_token_credential_inheritance(): class TestTokenCredential(AsyncTokenCredential): - async def get_token(self, *scopes, **kwargs): + async def get_token_info(self, *scopes, options=None): return "TOKEN" cred = TestTokenCredential() - await cred.get_token("scope") + token = await cred.get_token_info("scope") + assert token == "TOKEN" @pytest.mark.asyncio @@ -286,6 +289,33 @@ async def test_async_token_credential_asyncio_lock(): assert isinstance(auth_policy._lock, asyncio.Lock) -def test_async_token_credential_sync(): +async def test_async_token_credential_sync(): """Verify that AsyncBearerTokenCredentialPolicy can be constructed in a synchronous context.""" AsyncBearerTokenCredentialPolicy(Mock(), "scope") + + +async def test_need_new_token(): + expected_scope = "scope" + now = int(time.time()) + + policy = AsyncBearerTokenCredentialPolicy(Mock(), expected_scope) + + # Token is expired. + policy._token = AccessTokenInfo("", now - 1200) + assert policy._need_new_token + + # Token is about to expire within 300 seconds. + policy._token = AccessTokenInfo("", now + 299) + assert policy._need_new_token + + # Token still has more than 300 seconds to live. + policy._token = AccessTokenInfo("", now + 305) + assert not policy._need_new_token + + # Token has both expires_on and refresh_on set well into the future. + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now + 1200) + assert not policy._need_new_token + + # Token is not close to expiring, but refresh_on is in the past. + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1) + assert policy._need_new_token diff --git a/sdk/core/corehttp/tests/test_authentication.py b/sdk/core/corehttp/tests/test_authentication.py index 97e8a51315e8..f7a6a0da3be6 100644 --- a/sdk/core/corehttp/tests/test_authentication.py +++ b/sdk/core/corehttp/tests/test_authentication.py @@ -6,7 +6,7 @@ import time from unittest.mock import Mock -from corehttp.credentials import AccessToken, ServiceKeyCredential, ServiceNamedKeyCredential +from corehttp.credentials import AccessTokenInfo, ServiceKeyCredential, ServiceNamedKeyCredential from corehttp.exceptions import ServiceRequestError from corehttp.runtime.pipeline import Pipeline from corehttp.runtime.policies import ( @@ -22,24 +22,24 @@ def test_bearer_policy_adds_header(): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) - expected_token = AccessToken("expected_token", 2524608000) + expected_token = AccessTokenInfo("expected_token", 2524608000) def verify_authorization_header(request): assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) return Mock() - fake_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) + fake_credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=expected_token)) policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = Pipeline(transport=Mock(), policies=policies) pipeline.run(HttpRequest("GET", "https://spam.eggs")) - assert fake_credential.get_token.call_count == 1 + assert fake_credential.get_token_info.call_count == 1 pipeline.run(HttpRequest("GET", "https://spam.eggs")) # Didn't need a new token - assert fake_credential.get_token.call_count == 1 + assert fake_credential.get_token_info.call_count == 1 def test_bearer_policy_send(): @@ -51,10 +51,10 @@ def verify_request(request): assert request.http_request is expected_request return expected_response - def get_token(*_, **__): - return AccessToken("***", 42) + def get_token_info(*_, **__): + return AccessTokenInfo("***", 42) - fake_credential = Mock(spec_set=["get_token"], get_token=get_token) + fake_credential = Mock(spec_set=["get_token_info"], get_token_info=get_token_info) policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)] response = Pipeline(transport=Mock(), policies=policies).run(expected_request) @@ -62,26 +62,26 @@ def get_token(*_, **__): def test_bearer_policy_token_caching(): - good_for_one_hour = AccessToken("token", int(time.time()) + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=good_for_one_hour)) + good_for_one_hour = AccessTokenInfo("token", int(time.time()) + 3600) + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=good_for_one_hour)) pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) pipeline.run(HttpRequest("GET", "https://spam.eggs")) - assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token + assert credential.get_token_info.call_count == 1 # policy has no token at first request -> it should call get_token pipeline.run(HttpRequest("GET", "https://spam.eggs")) - assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache + assert credential.get_token_info.call_count == 1 # token is good for an hour -> policy should return it from cache - expired_token = AccessToken("token", int(time.time())) - credential.get_token.reset_mock() - credential.get_token.return_value = expired_token + expired_token = AccessTokenInfo("token", int(time.time())) + credential.get_token_info.reset_mock() + credential.get_token_info.return_value = expired_token pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) pipeline.run(HttpRequest("GET", "https://spam.eggs")) - assert credential.get_token.call_count == 1 + assert credential.get_token_info.call_count == 1 pipeline.run(HttpRequest("GET", "https://spam.eggs")) - assert credential.get_token.call_count == 2 # token expired -> policy should call get_token + assert credential.get_token_info.call_count == 2 # token expired -> policy should call get_token def test_bearer_policy_optionally_enforces_https(): @@ -91,10 +91,10 @@ def assert_option_popped(request, **kwargs): assert "enforce_https" not in kwargs, "BearerTokenCredentialPolicy didn't pop the 'enforce_https' option" return Mock() - def get_token(*_, **__): - return AccessToken("***", 42) + def get_token_info(*_, **__): + return AccessTokenInfo("***", 42) - credential = Mock(spec_set=["get_token"], get_token=get_token) + credential = Mock(spec_set=["get_token_info"], get_token_info=get_token_info) pipeline = Pipeline( transport=Mock(send=assert_option_popped), policies=[BearerTokenCredentialPolicy(credential, "scope")] ) @@ -122,7 +122,7 @@ def on_request(self, request): assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" return Mock() - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42))) + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=AccessTokenInfo("***", 42))) policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) @@ -132,14 +132,14 @@ def on_request(self, request): def test_bearer_policy_default_context(): """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" expected_scope = "scope" - token = AccessToken("", 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token)) + token = AccessTokenInfo("", 0) + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=token)) policy = BearerTokenCredentialPolicy(credential, expected_scope) pipeline = Pipeline(transport=Mock(), policies=[policy]) pipeline.run(HttpRequest("GET", "https://localhost")) - credential.get_token.assert_called_once_with(expected_scope) + credential.get_token_info.assert_called_once_with(expected_scope) def test_bearer_policy_context_unmodified_by_default(): @@ -149,7 +149,7 @@ class ContextValidator(SansIOHTTPPolicy): def on_request(self, request): assert not any(request.context), "the policy shouldn't add to the request's context" - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42))) + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=AccessTokenInfo("***", 42))) policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) @@ -166,7 +166,9 @@ def on_challenge(self, request, challenge): self.__class__.called = True return False - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600))) + credential = Mock( + spec_set=["get_token_info"], get_token_info=Mock(return_value=AccessTokenInfo("***", int(time.time()) + 3600)) + ) policies = [TestPolicy(credential, "scope")] response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) transport = Mock(send=Mock(return_value=response)) @@ -181,8 +183,8 @@ def test_bearer_policy_cannot_complete_challenge(): """BearerTokenCredentialPolicy should return the 401 response when it can't complete its challenge""" expected_scope = "scope" - expected_token = AccessToken("***", int(time.time()) + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) + expected_token = AccessTokenInfo("***", int(time.time()) + 3600) + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(return_value=expected_token)) expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) transport = Mock(send=Mock(return_value=expected_response)) policies = [BearerTokenCredentialPolicy(credential, expected_scope)] @@ -192,7 +194,7 @@ def test_bearer_policy_cannot_complete_challenge(): assert response.http_response is expected_response assert transport.send.call_count == 1 - credential.get_token.assert_called_once_with(expected_scope) + credential.get_token_info.assert_called_once_with(expected_scope) def test_bearer_policy_calls_sansio_methods(): @@ -210,7 +212,9 @@ def send(self, request): self.response = super(TestPolicy, self).send(request) return self.response - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600))) + credential = Mock( + spec_set=["get_token_info"], get_token_info=Mock(return_value=AccessTokenInfo("***", int(time.time()) + 3600)) + ) policy = TestPolicy(credential, "scope") transport = Mock(send=Mock(return_value=Mock(status_code=200))) @@ -384,3 +388,30 @@ def verify_authorization_header(request): pipeline = Pipeline(transport=transport, policies=[credential_policy]) pipeline.run(HttpRequest("GET", "https://test_key_credential")) + + +def test_need_new_token(): + expected_scope = "scope" + now = int(time.time()) + + policy = BearerTokenCredentialPolicy(Mock(), expected_scope) + + # Token is expired. + policy._token = AccessTokenInfo("", now - 1200) + assert policy._need_new_token + + # Token is about to expire within 300 seconds. + policy._token = AccessTokenInfo("", now + 299) + assert policy._need_new_token + + # Token still has more than 300 seconds to live. + policy._token = AccessTokenInfo("", now + 305) + assert not policy._need_new_token + + # Token has both expires_on and refresh_on set well into the future. + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now + 1200) + assert not policy._need_new_token + + # Token is not close to expiring, but refresh_on is in the past. + policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1) + assert policy._need_new_token From 3dbf70217b006e749328dee673bce5e5653eb29f Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 8 Nov 2024 02:54:33 +0000 Subject: [PATCH 2/2] Update changelog Signed-off-by: Paul Van Eck --- sdk/core/corehttp/CHANGELOG.md | 20 +++++++++++++++++++ sdk/core/corehttp/corehttp/_version.py | 2 +- .../samples/sample_async_pipeline_client.py | 10 ---------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sdk/core/corehttp/CHANGELOG.md b/sdk/core/corehttp/CHANGELOG.md index 50e2f169162e..8a4d87644799 100644 --- a/sdk/core/corehttp/CHANGELOG.md +++ b/sdk/core/corehttp/CHANGELOG.md @@ -1,5 +1,25 @@ # Release History +## 1.0.0b6 (Unreleased) + +### Features Added + +- The `TokenCredential` and `AsyncTokenCredential` protocols have been updated to include a new `get_token_info` method. This method should be used to acquire tokens and return an `AccessTokenInfo` object. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) +- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) +- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) +- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) + - These policies now also check the `refresh_on` attribute when determining if a new token request should be made. + +### Breaking Changes + +- The `get_token` method has been removed from the `TokenCredential` and `AsyncTokenCredential` protocols. Implementations should now use the new `get_token_info` method to acquire tokens. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) +- The `AccessToken` class has been removed and replaced with a new `AccessTokenInfo` class. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) +- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now rely on credentials having the `get_token_info` method defined. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346) + +### Bugs Fixed + +### Other Changes + ## 1.0.0b5 (2024-02-29) ### Other Changes diff --git a/sdk/core/corehttp/corehttp/_version.py b/sdk/core/corehttp/corehttp/_version.py index 34b338b446b9..d88f70a9762b 100644 --- a/sdk/core/corehttp/corehttp/_version.py +++ b/sdk/core/corehttp/corehttp/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b5" +VERSION = "1.0.0b6" diff --git a/sdk/core/corehttp/samples/sample_async_pipeline_client.py b/sdk/core/corehttp/samples/sample_async_pipeline_client.py index c8bc4f0d4d62..0e05cac9835c 100644 --- a/sdk/core/corehttp/samples/sample_async_pipeline_client.py +++ b/sdk/core/corehttp/samples/sample_async_pipeline_client.py @@ -16,16 +16,6 @@ import asyncio from typing import Iterable, Union -from corehttp.runtime import AsyncPipelineClient -from corehttp.rest import HttpRequest, AsyncHttpResponse -from corehttp.runtime.policies import ( - AsyncHTTPPolicy, - SansIOHTTPPolicy, - HeadersPolicy, - UserAgentPolicy, - AsyncRetryPolicy, -) - async def sample_pipeline_client(): # [START build_async_pipeline_client]