Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reginal endpoint support #358

Merged
merged 13 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 133 additions & 9 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .wstrust_response import *
from .token_cache import TokenCache
import msal.telemetry
from .region import _detect_region


# The __init__.py will import this. Not the other way around.
Expand Down Expand Up @@ -108,14 +109,21 @@ class ClientApplication(object):
GET_ACCOUNTS_ID = "902"
REMOVE_ACCOUNT_ID = "903"

ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
token_cache=None,
http_client=None,
verify=True, proxies=None, timeout=None,
client_claims=None, app_name=None, app_version=None,
client_capabilities=None):
client_capabilities=None,
azure_region=None, # Note: We choose to add this param in this base class,
# despite it is currently only needed by ConfidentialClientApplication.
# This way, it holds the same positional param place for PCA,
# when we would eventually want to add this feature to PCA in future.
):
"""Create an instance of application.

:param str client_id: Your app has a client_id after you register it on AAD.
Expand Down Expand Up @@ -220,6 +228,53 @@ def __init__(
MSAL will combine them into
`claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
which you will later provide via one of the acquire-token request.

:param str azure_region:
Added since MSAL Python 1.12.0.

As of 2021 May, regional service is only available for
``acquire_token_for_client()`` sent by any of the following scenarios::

1. An app powered by a capable MSAL
(MSAL Python 1.12+ will be provisioned)

2. An app with managed identity, which is formerly known as MSI.
(However MSAL Python does not support managed identity,
so this one does not apply.)

3. An app authenticated by
`Subject Name/Issuer (SNI) <https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/60>`_.

4. An app which already onboard to the region's allow-list.

MSAL's default value is None, which means region behavior remains off.
If enabled, the `acquire_token_for_client()`-relevant traffic
would remain inside that region.

App developer can opt in to a regional endpoint,
by provide its region name, such as "westus", "eastus2".
You can find a full list of regions by running
``az account list-locations -o table``, or referencing to
`this doc <https://docs.microsoft.com/en-us/dotnet/api/microsoft.azure.management.resourcemanager.fluent.core.region?view=azure-dotnet>`_.

An app running inside Azure Functions and Azure VM can use a special keyword
``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.

.. note::

Setting ``azure_region`` to non-``None`` for an app running
outside of Azure Function/VM could hang indefinitely.

You should consider opting in/out region behavior on-demand,
by loading ``azure_region=None`` or ``azure_region="westus"``
or ``azure_region=True`` (which means opt-in and auto-detect)
from your per-deployment configuration, and then do
``app = ConfidentialClientApplication(..., azure_region=azure_region)``.

Alternatively, you can configure a short timeout,
or provide a custom http_client which has a short timeout.
That way, the latency would be under your control,
but still less performant than opting out of region feature.
"""
self.client_id = client_id
self.client_credential = client_credential
Expand All @@ -244,12 +299,29 @@ def __init__(

self.app_name = app_name
self.app_version = app_version
self.authority = Authority(

# Here the self.authority will not be the same type as authority in input
try:
self.authority = Authority(
authority or "https://login.microsoftonline.com/common/",
self.http_client, validate_authority=validate_authority)
# Here the self.authority is not the same type as authority in input
except ValueError: # Those are explicit authority validation errors
raise
except Exception: # The rest are typically connection errors
if validate_authority and region:
# Since caller opts in to use region, here we tolerate connection
# errors happened during authority validation at non-region endpoint
self.authority = Authority(
authority or "https://login.microsoftonline.com/common/",
self.http_client, validate_authority=False)
else:
raise

self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self._region_configured = azure_region
self._region_detected = None
self.client, self._regional_client = self._build_client(
client_credential, self.authority)
self.authority_groups = None
self._telemetry_buffer = {}
self._telemetry_lock = Lock()
Expand All @@ -260,6 +332,32 @@ def _build_telemetry_context(
self._telemetry_buffer, self._telemetry_lock, api_id,
correlation_id=correlation_id, refresh_reason=refresh_reason)

def _get_regional_authority(self, central_authority):
is_region_specified = bool(self._region_configured
and self._region_configured != self.ATTEMPT_REGION_DISCOVERY)
self._region_detected = self._region_detected or _detect_region(
self.http_client if self._region_configured is not None else None)
if (is_region_specified and self._region_configured != self._region_detected):
logger.warning('Region configured ({}) != region detected ({})'.format(
repr(self._region_configured), repr(self._region_detected)))
region_to_use = (
self._region_configured if is_region_specified else self._region_detected)
rayluo marked this conversation as resolved.
Show resolved Hide resolved
if region_to_use:
logger.info('Region to be used: {}'.format(repr(region_to_use)))
regional_host = ("{}.login.microsoft.com".format(region_to_use)
rayluo marked this conversation as resolved.
Show resolved Hide resolved
if central_authority.instance in (
# The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328
"login.microsoftonline.com",
"login.windows.net",
"sts.windows.net",
)
else "{}.{}".format(region_to_use, central_authority.instance))
return Authority(
"https://{}/{}".format(regional_host, central_authority.tenant),
self.http_client,
validate_authority=False) # The central_authority has already been validated
return None

def _build_client(self, client_credential, authority):
client_assertion = None
client_assertion_type = None
Expand Down Expand Up @@ -298,15 +396,15 @@ def _build_client(self, client_credential, authority):
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
else:
default_body['client_secret'] = client_credential
server_configuration = {
central_configuration = {
"authorization_endpoint": authority.authorization_endpoint,
"token_endpoint": authority.token_endpoint,
"device_authorization_endpoint":
authority.device_authorization_endpoint or
urljoin(authority.token_endpoint, "devicecode"),
}
return Client(
server_configuration,
central_client = Client(
central_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
Expand All @@ -318,6 +416,31 @@ def _build_client(self, client_credential, authority):
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt)

regional_client = None
if client_credential: # Currently regional endpoint only serves some CCA flows
regional_authority = self._get_regional_authority(authority)
if regional_authority:
regional_configuration = {
"authorization_endpoint": regional_authority.authorization_endpoint,
"token_endpoint": regional_authority.token_endpoint,
"device_authorization_endpoint":
regional_authority.device_authorization_endpoint or
urljoin(regional_authority.token_endpoint, "devicecode"),
}
regional_client = Client(
regional_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
default_body=default_body,
client_assertion=client_assertion,
client_assertion_type=client_assertion_type,
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
event, environment=authority.instance)),
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt)
return central_client, regional_client

def initiate_auth_code_flow(
self,
scopes, # type: list[str]
Expand Down Expand Up @@ -953,7 +1076,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
# target=scopes, # AAD RTs are scope-independent
query=query)
logger.debug("Found %d RTs matching %s", len(matches), query)
client = self._build_client(self.client_credential, authority)
client, _ = self._build_client(self.client_credential, authority)

response = None # A distinguishable value to mean cache is empty
telemetry_context = self._build_telemetry_context(
Expand Down Expand Up @@ -1304,7 +1427,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
response = _clean_up(self.client.obtain_token_for_client(
client = self._regional_client or self.client
response = _clean_up(client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers=telemetry_context.generate_headers(),
data=dict(
Expand Down
47 changes: 47 additions & 0 deletions msal/region.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import logging

logger = logging.getLogger(__name__)


def _detect_region(http_client=None):
region = _detect_region_of_azure_function() # It is cheap, so we do it always
if http_client and not region:
rayluo marked this conversation as resolved.
Show resolved Hide resolved
return _detect_region_of_azure_vm(http_client) # It could hang for minutes
return region


def _detect_region_of_azure_function():
return os.environ.get("REGION_NAME")


def _detect_region_of_azure_vm(http_client):
rayluo marked this conversation as resolved.
Show resolved Hide resolved
url = (
"http://169.254.169.254/metadata/instance"

# Utilize the "route parameters" feature to obtain region as a string
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#route-parameters
"/compute/location?format=text"

# Location info is available since API version 2017-04-02
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#response-1
"&api-version=2021-01-01"
)
logger.info(
"Connecting to IMDS {}. "
"It may take a while if you are running outside of Azure. "
"You should consider opting in/out region behavior on-demand, "
'by loading a boolean flag "is_deployed_in_azure" '
'from your per-deployment config and then do '
'"app = ConfidentialClientApplication(..., '
'azure_region=is_deployed_in_azure)"'.format(url))
try:
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#instance-metadata
resp = http_client.get(url, headers={"Metadata": "true"})
except:
logger.info(
"IMDS {} unavailable. Perhaps not running in Azure VM?".format(url))
return None
else:
return resp.text.strip()

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
# We will go with "<4" for now, which is also what our another dependency,
# pyjwt, currently use.

"mock;python_version<'3.3'",
]
)

Loading