Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class _BearerTokenCredentialPolicyBase(object):
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, Iterable[str], Mapping[str, Any]) -> None
def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, Mapping[str, Any]) -> None
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
Expand Down Expand Up @@ -60,6 +60,6 @@ def send(self, request):
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
token = self._credential.get_token(self._scopes)
self._update_headers(request.http_request.headers, token) # type: ignore
token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, token)
return self.next.send(request)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ async def send(self, request: PipelineRequest) -> PipelineResponse:
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
token = await self._credential.get_token(self._scopes) # type: ignore
self._update_headers(request.http_request.headers, token) # type: ignore
token = await self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, token)
return await self.next.send(request) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def get_token(_):

fake_credential = Mock(get_token=get_token)
policies = [
AsyncBearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)),
AsyncBearerTokenCredentialPolicy(fake_credential, "scope"),
Mock(spec=HTTPPolicy, send=verify_authorization_header),
]
pipeline = AsyncPipeline(transport=Mock(spec=AsyncHttpTransport), policies=policies)
Expand All @@ -51,7 +51,7 @@ async def get_token(_):

fake_credential = Mock(get_token=get_token)
policies = [
AsyncBearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)),
AsyncBearerTokenCredentialPolicy(fake_credential, "scope"),
Mock(spec=HTTPPolicy, send=verify_request),
]
pipeline = AsyncPipeline(transport=Mock(spec=AsyncHttpTransport), policies=policies)
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def verify_authorization_header(request):

fake_credential = Mock(get_token=Mock(return_value=expected_token))
policies = [
BearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)),
BearerTokenCredentialPolicy(fake_credential, "scope"),
Mock(spec=HTTPPolicy, send=verify_authorization_header),
]

Expand All @@ -43,7 +43,7 @@ def verify_request(request):

fake_credential = Mock(get_token=lambda _: "")
policies = [
BearerTokenCredentialPolicy(credential=fake_credential, scopes=("",)),
BearerTokenCredentialPolicy(fake_credential, "scope"),
Mock(spec=HTTPPolicy, send=verify_request),
]
response = Pipeline(transport=Mock(spec=HttpTransport), policies=policies).run(expected_request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"])


KEY_VAULT_SCOPES = ("https://vault.azure.net/.default",)
KEY_VAULT_SCOPE = "https://vault.azure.net/.default"


def _parse_vault_id(url):
Expand Down Expand Up @@ -68,7 +68,7 @@ def create_config(credential, api_version=None, **kwargs):
if api_version is None:
api_version = KeyVaultClient.DEFAULT_API_VERSION
config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs)
config.authentication_policy = BearerTokenCredentialPolicy(credential, scopes=KEY_VAULT_SCOPES)
config.authentication_policy = BearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE)
return config

def __init__(self, vault_url, credential, config=None, transport=None, api_version=None, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from msrest.serialization import Model

from .._generated import KeyVaultClient
from .._internal import KEY_VAULT_SCOPES
from .._internal import KEY_VAULT_SCOPE


class AsyncPagingAdapter:
Expand Down Expand Up @@ -58,7 +58,7 @@ def create_config(
if api_version is None:
api_version = KeyVaultClient.DEFAULT_API_VERSION
config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs)
config.authentication_policy = AsyncBearerTokenCredentialPolicy(credential, scopes=KEY_VAULT_SCOPES)
config.authentication_policy = AsyncBearerTokenCredentialPolicy(credential, KEY_VAULT_SCOPE)
return config

def __init__(
Expand Down