diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py index 26a67090b05a..a5047a6f5e67 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py @@ -26,8 +26,8 @@ class _BearerTokenCredentialPolicyBase(object): :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential, scopes, **kwargs): # pylint:disable=unused-argument - # type: (TokenCredential, Iterable[str], Mapping[str, Any]) -> None + def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (TokenCredential, *str, Mapping[str, Any]) -> None super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential @@ -60,6 +60,6 @@ def send(self, request): :return: The pipeline response object :rtype: ~azure.core.pipeline.PipelineResponse """ - token = self._credential.get_token(self._scopes) - self._update_headers(request.http_request.headers, token) # type: ignore + token = self._credential.get_token(*self._scopes) + self._update_headers(request.http_request.headers, token) return self.next.send(request) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py index 641f7fae5fed..67d9abf19486 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py @@ -25,6 +25,6 @@ async def send(self, request: PipelineRequest) -> PipelineResponse: :return: The pipeline response object :rtype: ~azure.core.pipeline.PipelineResponse """ - token = await self._credential.get_token(self._scopes) # type: ignore - self._update_headers(request.http_request.headers, token) # type: ignore + token = await self._credential.get_token(*self._scopes) + self._update_headers(request.http_request.headers, token) return await self.next.send(request) # type: ignore diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_authentication.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_authentication.py index 345f7751d0b5..10e4067daf97 100644 --- a/sdk/core/azure-core/tests/azure_core_asynctests/test_authentication.py +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_authentication.py @@ -27,7 +27,7 @@ async def get_token(_): fake_credential = Mock(get_token=get_token) policies = [ - AsyncBearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)), + AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(spec=HTTPPolicy, send=verify_authorization_header), ] pipeline = AsyncPipeline(transport=Mock(spec=AsyncHttpTransport), policies=policies) @@ -51,7 +51,7 @@ async def get_token(_): fake_credential = Mock(get_token=get_token) policies = [ - AsyncBearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)), + AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(spec=HTTPPolicy, send=verify_request), ] pipeline = AsyncPipeline(transport=Mock(spec=AsyncHttpTransport), policies=policies) diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index 032c179d6ca2..aa74ae46ffd2 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -23,7 +23,7 @@ def verify_authorization_header(request): fake_credential = Mock(get_token=Mock(return_value=expected_token)) policies = [ - BearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)), + BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(spec=HTTPPolicy, send=verify_authorization_header), ] @@ -43,7 +43,7 @@ def verify_request(request): fake_credential = Mock(get_token=lambda _: "") policies = [ - BearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)), + BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(spec=HTTPPolicy, send=verify_request), ] response = Pipeline(transport=Mock(spec=HttpTransport), policies=policies).run(expected_request) diff --git a/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/_internal.py b/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/_internal.py index ecd11ad97617..00396276606a 100644 --- a/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/_internal.py +++ b/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/_internal.py @@ -28,7 +28,7 @@ _VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"]) -KEY_VAULT_SCOPES = ("https://vault.azure.net/.default",) +KEY_VAULT_SCOPE = "https://vault.azure.net/.default" def _parse_vault_id(url): @@ -68,7 +68,7 @@ def create_config(credential, api_version=None, **kwargs): if api_version is None: api_version = KeyVaultClient.DEFAULT_API_VERSION config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs) - config.authentication_policy = BearerTokenCredentialPolicy(credential, scopes=KEY_VAULT_SCOPES) + config.authentication_policy = BearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE) return config def __init__(self, vault_url, credential, config=None, transport=None, api_version=None, **kwargs): diff --git a/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/aio/_internal.py b/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/aio/_internal.py index 34599bb92b53..2337a97d5064 100644 --- a/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/aio/_internal.py +++ b/sdk/keyvault/azure-security-keyvault/azure/security/keyvault/aio/_internal.py @@ -20,7 +20,7 @@ from msrest.serialization import Model from .._generated import KeyVaultClient -from .._internal import KEY_VAULT_SCOPES +from .._internal import KEY_VAULT_SCOPE class AsyncPagingAdapter: @@ -58,7 +58,7 @@ def create_config( if api_version is None: api_version = KeyVaultClient.DEFAULT_API_VERSION config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs) - config.authentication_policy = AsyncBearerTokenCredentialPolicy(credential, scopes=KEY_VAULT_SCOPES) + config.authentication_policy = AsyncBearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE) return config def __init__(