-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Expose methods for closing async credential transport sessions #9090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
b4dacdd
812ce4f
e4779da
dea61f8
c805092
6f7b85c
3bb53ca
d70940d
4618c97
87fcd36
6f34be0
ab12fb9
6cd4eb9
fcabaa6
58d1b2e
231bed9
a3cb1a4
5330c31
6e206e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # ------------------------------------ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Any | ||
| from typing_extensions import Protocol | ||
| from .credentials import AccessToken | ||
|
|
||
| class AsyncTokenCredential(Protocol): | ||
| async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: | ||
| pass | ||
|
|
||
| async def close(self) -> None: | ||
| pass | ||
|
|
||
| async def __aenter__(self): | ||
| pass | ||
|
|
||
| async def __aexit__(self, exc_type, exc_value, traceback) -> None: | ||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # ------------------------------------ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| import abc | ||
|
|
||
|
|
||
| class AsyncCredentialBase(abc.ABC): | ||
| @abc.abstractmethod | ||
| async def close(self): | ||
| pass | ||
|
|
||
| async def __aenter__(self): | ||
| return self | ||
|
|
||
| async def __aexit__(self, *args): | ||
| await self.close() | ||
|
|
||
| @abc.abstractmethod | ||
| async def get_token(self, *scopes, **kwargs): | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,17 +2,19 @@ | |
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| import asyncio | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from azure.core.exceptions import ClientAuthenticationError | ||
| from ... import ChainedTokenCredential as SyncChainedTokenCredential | ||
| from .base import AsyncCredentialBase | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Any | ||
| from azure.core.credentials import AccessToken | ||
|
|
||
|
|
||
| class ChainedTokenCredential(SyncChainedTokenCredential): | ||
| class ChainedTokenCredential(SyncChainedTokenCredential, AsyncCredentialBase): | ||
| """A sequence of credentials that is itself a credential. | ||
|
|
||
| Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first | ||
|
|
@@ -22,6 +24,11 @@ class ChainedTokenCredential(SyncChainedTokenCredential): | |
| :type credentials: :class:`azure.core.credentials.TokenCredential` | ||
| """ | ||
|
|
||
| async def close(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a concern here: do we expect customer to "async enter" all the credentials in the chain, or should we have a aenter here that loop thourgh all of them and enter them?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is what I want to enable: credential = DefaultAzureCredential()
client = FooServiceClient(credential)
# ... time passes, many useful service requests are authorized ...
credential.close()I think
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you implement If we don't want to give an example of intended use with
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be consistent with everything else in the SDK that wraps async transport by exposing The awkwardness for |
||
| """Close the transport sessions of all credentials in the chain.""" | ||
|
|
||
| await asyncio.gather(*(credential.close() for credential in self.credentials)) | ||
|
|
||
| async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument | ||
| """Asynchronously request a token from each credential, in order, returning the first token received. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| import abc | ||
| import os | ||
| from typing import TYPE_CHECKING | ||
|
|
||
|
|
@@ -10,6 +11,7 @@ | |
| from azure.core.pipeline.policies import AsyncRetryPolicy | ||
|
|
||
| from azure.identity._credentials.managed_identity import _ManagedIdentityBase | ||
| from .base import AsyncCredentialBase | ||
| from .._authn_client import AsyncAuthnClient | ||
| from ..._constants import Endpoints, EnvironmentVariables | ||
|
|
||
|
|
@@ -37,6 +39,15 @@ def __new__(cls, *args, **kwargs): | |
| def __init__(self, **kwargs: "Any") -> None: | ||
| pass | ||
|
|
||
| async def __aenter__(self): | ||
|
||
| pass | ||
|
|
||
| async def __aexit__(self, *args): | ||
| pass | ||
|
|
||
| async def close(self): | ||
| """Close the credential's transport session.""" | ||
|
|
||
| async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument | ||
| """Asynchronously request an access token for `scopes`. | ||
|
|
||
|
|
@@ -49,10 +60,23 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py | |
| return AccessToken() | ||
|
|
||
|
|
||
| class _AsyncManagedIdentityBase(_ManagedIdentityBase): | ||
| class _AsyncManagedIdentityBase(_ManagedIdentityBase, AsyncCredentialBase): | ||
| def __init__(self, endpoint: str, **kwargs: "Any") -> None: | ||
| super().__init__(endpoint=endpoint, client_cls=AsyncAuthnClient, **kwargs) | ||
|
|
||
| async def __aenter__(self): | ||
| await self._client.__aenter__() | ||
| return self | ||
|
|
||
| async def close(self): | ||
| """Close the credential's transport session.""" | ||
|
|
||
| await self._client.__aexit__() | ||
|
|
||
| @abc.abstractmethod | ||
| async def get_token(self, *scopes, **kwargs): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def _create_config(**kwargs: "Any") -> "Configuration": | ||
| """Build a default configuration for the credential's HTTP pipeline.""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.