Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reginal endpoint support #358

Merged
merged 13 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 95 additions & 7 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .wstrust_response import *
from .token_cache import TokenCache
import msal.telemetry
from .region import _detect_region


# The __init__.py will import this. Not the other way around.
Expand Down Expand Up @@ -108,14 +109,21 @@ class ClientApplication(object):
GET_ACCOUNTS_ID = "902"
REMOVE_ACCOUNT_ID = "903"

ATTEMPT_REGION_DISCOVERY = "TryAutoDetect"

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
token_cache=None,
http_client=None,
verify=True, proxies=None, timeout=None,
client_claims=None, app_name=None, app_version=None,
client_capabilities=None):
client_capabilities=None,
region=None, # Note: We choose to add this param in this base class,
rayluo marked this conversation as resolved.
Show resolved Hide resolved
# despite it is currently only needed by ConfidentialClientApplication.
# This way, it holds the same positional param place for PCA,
# when we would eventually want to add this feature to PCA in future.
):
"""Create an instance of application.

:param str client_id: Your app has a client_id after you register it on AAD.
Expand Down Expand Up @@ -220,6 +228,36 @@ def __init__(
MSAL will combine them into
`claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
which you will later provide via one of the acquire-token request.

:param str region:
Added since MSAL Python 1.12.0.

As of 2021 May, regional service is only available for
``acquire_token_for_client()`` sent by any of the following scenarios::

1. An app powered by a capable MSAL
(MSAL Python 1.12+ will be provisioned)

2. An app with managed identity, which is formerly known as MSI.
(However MSAL Python does not support managed identity,
so this one does not apply.)

3. An app authenticated by Subject Name/Issuer (SNI).

4. An app which already onboard to the region's allow-list.

MSAL's default value is None, which means region behavior remains off.
If enabled, some of the MSAL traffic would remain inside that region.
rayluo marked this conversation as resolved.
Show resolved Hide resolved

App developer can opt in to regional endpoint,
by provide a region name, such as "westus", "eastus2".
rayluo marked this conversation as resolved.
Show resolved Hide resolved

An app running inside Azure VM can use a special keyword
rayluo marked this conversation as resolved.
Show resolved Hide resolved
``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
(Attempting this on a non-VM could hang indefinitely.
Make sure you configure a short timeout,
or provide a custom http_client which has a short timeout.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

http_client

imo, we should have dedicated http client with a shorter timeout for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having 2 different http clients (one dedicated for region detection and one for the rest), would complicate not just MSAL's internal implementation, but also the API model, which would then propagate to our app developer's implementation. Yet the gain is debatable.

When an http_client has short timeout (say, 2 seconds x 2 retries = 4 seconds latency), those seemingly short latency might go unnoticed, and become a perpetual behavior for that app. (If that app happens to be a command-line app such as Azure CLI, it would mean each "az ..." would have 4+ seconds extra latency.)

The current approach would have a longer latency by default. But hopefully this "fail early, fail loudly" approach would lead customer to a better solution: to make their "region" setting configurable, so that it could (and should) be completely turned off by an on-site configuration, when their app is deployed outside of Azure VMs.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a http client dedicated for region, but a client dedicated to fetch data from the imds endpoint. Primary reason is the timeout setting from above.

Copy link

@neha-bhargava neha-bhargava May 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, usually imds call fails fast but with the default retry logic it was taking longer to fail making the overall response time longer. So it was decided to set timeout to 2 sec for the imds call specifically with 1 retry.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is: (1) if your app is running inside Azure VM (or Azure Function etc., for that matter), the detection is a local http call which would be quick (even quicker than other cross-machine http requests), so customizing timeout would be unnecessary; (2) if your app is running outside of Azure infrastructure, you wouldn't want to waste some latency on a meaningless region detection.

The default longer timeout would cause app developer in scenario No.2 to notice the hanging during their smoke testing, which would then lead them to somehow add a per-deployment flag to opt out of the region behavior when/if their app is not deployed on Azure. After that, their app would be more performant when deployed to production.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rayluo agree that there is a need to set the timeout specifically for the imds call to 2 seconds.

That way, the latency would be under your control.)
"""
self.client_id = client_id
self.client_credential = client_credential
Expand Down Expand Up @@ -249,7 +287,10 @@ def __init__(
self.http_client, validate_authority=validate_authority)
# Here the self.authority is not the same type as authority in input
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self._region_configured = region
self._region_detected = None
self.client, self._regional_client = self._build_client(
client_credential, self.authority)
self.authority_groups = None
self._telemetry_buffer = {}
self._telemetry_lock = Lock()
Expand All @@ -260,6 +301,27 @@ def _build_telemetry_context(
self._telemetry_buffer, self._telemetry_lock, api_id,
correlation_id=correlation_id, refresh_reason=refresh_reason)

def _get_regional_authority(self, central_authority):
is_region_specified = bool(self._region_configured
and self._region_configured != self.ATTEMPT_REGION_DISCOVERY)
self._region_detected = self._region_detected or _detect_region(
self.http_client if self._region_configured is not None else None)
if (is_region_specified and self._region_configured != self._region_detected):
logger.warning('Region configured ({}) != region detected ({})'.format(
repr(self._region_configured), repr(self._region_detected)))
region_to_use = (
self._region_configured if is_region_specified else self._region_detected)
rayluo marked this conversation as resolved.
Show resolved Hide resolved
if region_to_use:
logger.info('Region to be used: {}'.format(repr(region_to_use)))
regional_host = ("{}.login.microsoft.com".format(region_to_use)
rayluo marked this conversation as resolved.
Show resolved Hide resolved
if central_authority.instance == "login.microsoftonline.com"
else "{}.{}".format(region_to_use, central_authority.instance))
return Authority(
"https://{}/{}".format(regional_host, central_authority.tenant),
self.http_client,
validate_authority=False) # The central_authority has already been validated
return None

def _build_client(self, client_credential, authority):
client_assertion = None
client_assertion_type = None
Expand Down Expand Up @@ -298,15 +360,15 @@ def _build_client(self, client_credential, authority):
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
else:
default_body['client_secret'] = client_credential
server_configuration = {
central_configuration = {
"authorization_endpoint": authority.authorization_endpoint,
"token_endpoint": authority.token_endpoint,
"device_authorization_endpoint":
authority.device_authorization_endpoint or
urljoin(authority.token_endpoint, "devicecode"),
}
return Client(
server_configuration,
central_client = Client(
central_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
Expand All @@ -318,6 +380,31 @@ def _build_client(self, client_credential, authority):
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt)

regional_client = None
if client_credential: # Currently regional endpoint only serves some CCA flows
regional_authority = self._get_regional_authority(authority)
if regional_authority:
regional_configuration = {
"authorization_endpoint": regional_authority.authorization_endpoint,
"token_endpoint": regional_authority.token_endpoint,
"device_authorization_endpoint":
regional_authority.device_authorization_endpoint or
urljoin(regional_authority.token_endpoint, "devicecode"),
}
regional_client = Client(
regional_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
default_body=default_body,
client_assertion=client_assertion,
client_assertion_type=client_assertion_type,
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
event, environment=authority.instance)),
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt)
return central_client, regional_client

def initiate_auth_code_flow(
self,
scopes, # type: list[str]
Expand Down Expand Up @@ -953,7 +1040,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
# target=scopes, # AAD RTs are scope-independent
query=query)
logger.debug("Found %d RTs matching %s", len(matches), query)
client = self._build_client(self.client_credential, authority)
client, _ = self._build_client(self.client_credential, authority)

response = None # A distinguishable value to mean cache is empty
telemetry_context = self._build_telemetry_context(
Expand Down Expand Up @@ -1304,7 +1391,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
response = _clean_up(self.client.obtain_token_for_client(
client = self._regional_client or self.client
response = _clean_up(client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers=telemetry_context.generate_headers(),
data=dict(
Expand Down
33 changes: 33 additions & 0 deletions msal/region.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import json
import logging

logger = logging.getLogger(__name__)


def _detect_region(http_client=None):
region = _detect_region_of_azure_function() # It is cheap, so we do it always
if http_client and not region:
rayluo marked this conversation as resolved.
Show resolved Hide resolved
return _detect_region_of_azure_vm(http_client) # It could hang for minutes
return region


def _detect_region_of_azure_function():
return os.environ.get("REGION_NAME")


def _detect_region_of_azure_vm(http_client):
rayluo marked this conversation as resolved.
Show resolved Hide resolved
url = "http://169.254.169.254/metadata/instance?api-version=2021-01-01"
rayluo marked this conversation as resolved.
Show resolved Hide resolved
logger.info(
"Connecting to IMDS {}. "
"You may want to use a shorter timeout on your http_client".format(url))
try:
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#instance-metadata
resp = http_client.get(url, headers={"Metadata": "true"})
except:
logger.info(
"IMDS {} unavailable. Perhaps not running in Azure VM?".format(url))
return None
else:
return json.loads(resp.text)["compute"]["location"]

106 changes: 99 additions & 7 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def assertCacheWorksForUser(
"We should get an AT from acquire_token_silent(...) call")

def assertCacheWorksForApp(self, result_from_wire, scope):
logger.debug(
"%s: cache = %s, id_token_claims = %s",
self.id(),
json.dumps(self.app.token_cache._cache, indent=4),
json.dumps(result_from_wire.get("id_token_claims"), indent=4),
)
# Going to test acquire_token_silent(...) to locate an AT from cache
result_from_cache = self.app.acquire_token_silent(scope, account=None)
self.assertIsNotNone(result_from_cache)
Expand Down Expand Up @@ -345,7 +351,9 @@ def test_device_flow(self):
def get_lab_app(
env_client_id="LAB_APP_CLIENT_ID",
env_client_secret="LAB_APP_CLIENT_SECRET",
):
authority="https://login.microsoftonline.com/"
"72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID
**kwargs):
"""Returns the lab app as an MSAL confidential client.

Get it from environment variables if defined, otherwise fall back to use MSI.
Expand All @@ -367,16 +375,21 @@ def get_lab_app(
env_client_id, env_client_secret)
# See also https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/Programmatically-accessing-LAB-API's.aspx
raise unittest.SkipTest("MSI-based mechanism has not been implemented yet")
return msal.ConfidentialClientApplication(client_id, client_secret,
authority="https://login.microsoftonline.com/"
"72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID
http_client=MinimalHttpClient())
return msal.ConfidentialClientApplication(
client_id,
client_credential=client_secret,
authority=authority,
http_client=MinimalHttpClient(),
**kwargs)

def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow
logger.info("Creating session")
lab_token = lab_app.acquire_token_for_client(scopes)
result = lab_app.acquire_token_for_client(scopes)
assert result.get("access_token"), \
"Unable to obtain token for lab. Encountered {}: {}".format(
result.get("error"), result.get("error_description"))
session = requests.Session()
session.headers.update({"Authorization": "Bearer %s" % lab_token["access_token"]})
session.headers.update({"Authorization": "Bearer %s" % result["access_token"]})
session.hooks["response"].append(lambda r, *args, **kwargs: r.raise_for_status())
return session

Expand Down Expand Up @@ -726,6 +739,85 @@ def test_b2c_acquire_token_by_ropc(self):
)


class WorldWideRegionalEndpointTestCase(LabBasedTestCase):
region = "westus"

def test_acquire_token_for_client_should_hit_regional_endpoint(self):
"""This is the only grant supported by regional endpoint, for now"""
self.app = get_lab_app( # Regional endpoint only supports confidential client
## Would fail the OIDC Discovery
#authority="https://westus2.login.microsoftonline.com/"
# "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID

#authority="https://westus.login.microsoft.com/microsoft.onmicrosoft.com",
#validate_authority=False,

authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com",
region=self.region, # Explicitly use this region, regardless of detection
)
scopes = ["https://graph.microsoft.com/.default"]
result = self.app.acquire_token_for_client(
scopes,
params={"AllowEstsRNonMsi": "true"}, # For testing regional endpoint. It will be removed once MSAL Python 1.12+ has been onboard to ESTS-R
)
self.assertIn('access_token', result)
self.assertCacheWorksForApp(result, scopes)
# TODO: Test the request hit the regional endpoint self.region?


class RegionalEndpointViaEnvVarTestCase(WorldWideRegionalEndpointTestCase):

def setUp(self):
os.environ["REGION_NAME"] = "eastus"

def tearDown(self):
del os.environ["REGION_NAME"]

@unittest.skipUnless(
os.getenv("LAB_OBO_CLIENT_SECRET"),
"Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO")
@unittest.skipUnless(
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
"Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
@unittest.skipUnless(
os.getenv("LAB_OBO_PUBLIC_CLIENT_ID"),
"Need LAB_OBO_PUBLIC_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
def test_cca_obo_should_bypass_regional_endpoint_therefore_still_work(self):
"""We test OBO because it is implemented in sub class ConfidentialClientApplication"""
config = self.get_lab_user(usertype="cloud")

config_cca = {}
config_cca.update(config)
config_cca["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID")
config_cca["scope"] = ["https://graph.microsoft.com/.default"]
config_cca["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET")

config_pca = {}
config_pca.update(config)
config_pca["client_id"] = os.getenv("LAB_OBO_PUBLIC_CLIENT_ID")
config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"])
config_pca["scope"] = ["api://%s/read" % config_cca["client_id"]]

self._test_acquire_token_obo(config_pca, config_cca)

@unittest.skipUnless(
os.getenv("LAB_OBO_CLIENT_SECRET"),
"Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO")
@unittest.skipUnless(
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
"Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
def test_cca_ropc_should_bypass_regional_endpoint_therefore_still_work(self):
"""We test ROPC because it is implemented in base class ClientApplication"""
config = self.get_lab_user(usertype="cloud")
config["password"] = self.get_lab_user_secret(config["lab_name"])
# We repurpose the obo confidential app to test ROPC
# Swap in the OBO confidential app
config["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID")
config["scope"] = ["https://graph.microsoft.com/.default"]
config["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET")
self._test_username_password(**config)


class ArlingtonCloudTestCase(LabBasedTestCase):
environment = "azureusgovernment"

Expand Down