diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index 2887311fd58b..c2a6776f7213 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -38,9 +38,9 @@ def get_token(self, *scopes, **kwargs): ) return super(AppServiceCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 03f7f05eddd2..f7069f697f23 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -60,9 +60,9 @@ def get_token(self, *scopes, **kwargs): ) return super(AzureArcCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index 067cc092d898..e9634594a5a6 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -42,9 +42,9 @@ def get_token(self, *scopes, **kwargs): ) return super(CloudShellCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index 2c48db078b0f..691f5c578c09 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -38,9 +38,9 @@ def get_token(self, *scopes, **kwargs): ) return super(ServiceFabricCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] - return self._client.get_cached_token(*scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index 68fc0df801ea..84f3e5674007 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -45,7 +45,7 @@ def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> Optional[AccessToken] app = self._get_app() request_time = int(time.time()) - result = app.acquire_token_for_client(list(scopes)) + result = app.acquire_token_for_client(list(scopes), **kwargs) if "access_token" not in result: message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index c927504d2bd3..9d66e6d316b4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -31,8 +31,8 @@ def __init__(self, *args, **kwargs): super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - def _acquire_token_silently(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] """Attempt to acquire an access token from a cache or by redeeming a refresh token""" @abc.abstractmethod @@ -66,10 +66,10 @@ def get_token(self, *scopes, **kwargs): raise ValueError('"get_token" requires at least one scope') try: - token = self._acquire_token_silently(*scopes) + token = self._acquire_token_silently(*scopes, **kwargs) if not token: self._last_request_time = int(time.time()) - token = self._request_token(*scopes) + token = self._request_token(*scopes, **kwargs) elif self._should_refresh(token): try: self._last_request_time = int(time.time()) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 1a943b070fca..32a60598fc7f 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -84,8 +84,8 @@ def _process_response(self, response, request_time): return token - def get_cached_token(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def get_cached_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (*str, **Any) -> Optional[AccessToken] resource = _scopes_to_resource(*scopes) tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]) for token in tokens: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index 0f9f5f3e0783..63143646679f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -39,8 +39,8 @@ async def get_token( # pylint:disable=invalid-overridden-method async def close(self) -> None: await self._client.close() # pylint:disable=no-member - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index 09848f1e452a..586397c17545 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -65,8 +65,8 @@ async def get_token( # pylint:disable=invalid-overridden-method async def close(self) -> None: await self._client.close() # pylint:disable=no-member - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py index 9617becdc67d..ee6c657dc855 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py @@ -43,8 +43,8 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py index be6f18fcf536..bf30b2e5d381 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py @@ -39,8 +39,8 @@ async def get_token( # pylint:disable=invalid-overridden-method async def close(self) -> None: await self._client.close() # pylint:disable=no-member - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": - return self._client.get_cached_token(*scopes) + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": return await self._client.request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index 204cffff1d6d..4f7245164fba 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -25,7 +25,7 @@ def __init__(self, *args: "Any", **kwargs: "Any") -> None: super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": """Attempt to acquire an access token from a cache or by redeeming a refresh token""" @abc.abstractmethod diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin.py b/sdk/identity/azure-identity/tests/test_get_token_mixin.py index 28d2e5df2705..ea171ec39e68 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin.py @@ -24,8 +24,8 @@ def __init__(self, cached_token=None): self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN) self.acquire_token_silently = mock.Mock(return_value=cached_token) - def _acquire_token_silently(self, *scopes): - return self.acquire_token_silently(*scopes) + def _acquire_token_silently(self, *scopes, **kwargs): + return self.acquire_token_silently(*scopes, **kwargs) def _request_token(self, *scopes, **kwargs): return self.request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py index a76c7a82faf7..bd9ee0fb476c 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py @@ -23,8 +23,8 @@ def __init__(self, cached_token=None): self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN) self.acquire_token_silently = mock.Mock(return_value=cached_token) - async def _acquire_token_silently(self, *scopes): - return self.acquire_token_silently(*scopes) + async def _acquire_token_silently(self, *scopes, **kwargs): + return self.acquire_token_silently(*scopes, **kwargs) async def _request_token(self, *scopes, **kwargs): return self.request_token(*scopes, **kwargs)