-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Expose methods for closing async credential transport sessions #9090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b4dacdd
812ce4f
e4779da
dea61f8
c805092
6f7b85c
3bb53ca
d70940d
4618c97
87fcd36
6f34be0
ab12fb9
6cd4eb9
fcabaa6
58d1b2e
231bed9
a3cb1a4
5330c31
6e206e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # ------------------------------------ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Any | ||
| from typing_extensions import Protocol | ||
| from .credentials import AccessToken | ||
|
|
||
| class AsyncTokenCredential(Protocol): | ||
| async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: | ||
| pass | ||
|
|
||
| async def close(self) -> None: | ||
| pass | ||
|
|
||
| async def __aenter__(self): | ||
| pass | ||
|
|
||
| async def __aexit__(self, exc_type, exc_value, traceback) -> None: | ||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # ------------------------------------ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| import abc | ||
|
|
||
|
|
||
| class AsyncCredentialBase(abc.ABC): | ||
| @abc.abstractmethod | ||
| async def close(self): | ||
| pass | ||
|
|
||
| async def __aenter__(self): | ||
| return self | ||
|
|
||
| async def __aexit__(self, *args): | ||
| await self.close() | ||
|
|
||
| @abc.abstractmethod | ||
| async def get_token(self, *scopes, **kwargs): | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,27 +2,40 @@ | |
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| # ------------------------------------ | ||
| import asyncio | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from azure.core.exceptions import ClientAuthenticationError | ||
| from ... import ChainedTokenCredential as SyncChainedTokenCredential | ||
| from .base import AsyncCredentialBase | ||
| from ..._credentials.chained import _get_error_message | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Any | ||
| from azure.core.credentials import AccessToken | ||
| from azure.core.credentials_async import AsyncTokenCredential | ||
|
|
||
|
|
||
| class ChainedTokenCredential(SyncChainedTokenCredential): | ||
| class ChainedTokenCredential(AsyncCredentialBase): | ||
| """A sequence of credentials that is itself a credential. | ||
|
|
||
| Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first | ||
| valid token received. | ||
|
|
||
| :param credentials: credential instances to form the chain | ||
| :type credentials: :class:`azure.core.credentials.TokenCredential` | ||
| :type credentials: :class:`azure.core.credentials.AsyncTokenCredential` | ||
| """ | ||
|
|
||
| async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument | ||
| def __init__(self, *credentials: "AsyncTokenCredential") -> None: | ||
| if not credentials: | ||
| raise ValueError("at least one credential is required") | ||
| self.credentials = credentials | ||
|
|
||
| async def close(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a concern here: do we expect customer to "async enter" all the credentials in the chain, or should we have a aenter here that loop thourgh all of them and enter them?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is what I want to enable: credential = DefaultAzureCredential()
client = FooServiceClient(credential)
# ... time passes, many useful service requests are authorized ...
credential.close()I think
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you implement If we don't want to give an example of intended use with
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be consistent with everything else in the SDK that wraps async transport by exposing The awkwardness for |
||
| """Close the transport sessions of all credentials in the chain.""" | ||
|
|
||
| await asyncio.gather(*(credential.close() for credential in self.credentials)) | ||
|
|
||
| async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": | ||
| """Asynchronously request a token from each credential, in order, returning the first token received. | ||
|
|
||
| If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` | ||
|
|
@@ -41,5 +54,5 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py | |
| history.append((credential, ex.message)) | ||
| except Exception as ex: # pylint: disable=broad-except | ||
| history.append((credential, str(ex))) | ||
| error_message = self._get_error_message(history) | ||
| error_message = _get_error_message(history) | ||
| raise ClientAuthenticationError(message=error_message) | ||
Uh oh!
There was an error while loading. Please reload this page.