Skip to content

Commit

Permalink
POC of MSAL integration
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Dec 2, 2021
1 parent 1e2fa15 commit b488f16
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
24 changes: 24 additions & 0 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
try:
if sys.platform == "win32":
from .wam import _acquire_token_silently, _read_account_by_id
return _acquire_token_silently(
"https://{}/{}".format(self.authority.instance, self.authority.tenant), # TODO: What about B2C & ADFS?
self.client_id,
_read_account_by_id(account["local_account_id"]),
" ".join(scopes))
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, self._decorate_scope(scopes), account,
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
Expand Down Expand Up @@ -1553,6 +1560,23 @@ def acquire_token_interactive(
and typically contains an "access_token" key.
- A dict containing an "error" key, when token refresh failed.
"""
if sys.platform == "win32":
from .wam import _signin_interactively
response = _signin_interactively(
"https://{}/{}".format(self.authority.instance, self.authority.tenant), # TODO: What about B2C & ADFS?
self.client_id,
" ".join(scopes),
login_hint=login_hint)
if response.get("error") != "TBD: Broker Unavailable": # TODO
self.token_cache.add(dict(
client_id=self.client_id,
scope=scopes,
token_endpoint=self.authority.token_endpoint,
response=response.copy(),
data=kwargs.get("data", {}),
))
return response

self._validate_ssh_cert_input_data(kwargs.get("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
Expand Down
3 changes: 2 additions & 1 deletion msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __add(self, event, now=None):
id_token = response.get("id_token")
id_token_claims = (
decode_id_token(id_token, client_id=event["client_id"])
if id_token else {})
if id_token
else response.get("id_token_claims", {})) # Mid-tier would provide id_token_claims
client_info, home_account_id = self.__parse_account(response, id_token_claims)

target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
Expand Down
12 changes: 7 additions & 5 deletions msal/wam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pymsalruntime # See https://github.com/AzureAD/microsoft-authentication-library-for-cpp/pull/2419/files#diff-d5ea5122ff04e14411a4f695895c923daba73c117d6c8ceb19c4fa3520c3c08a
import win32gui # Came from package pywin32


logger = logging.getLogger(__name__)


Expand All @@ -28,13 +27,14 @@ def _read_account_by_id(account_id):
callback_data = _CallbackData()
pymsalruntime.read_account_by_id(
account_id,
"correlation_id",
lambda result, callback_data=callback_data: callback_data.complete(result)
)
callback_data.signal.wait()
return callback_data.auth_result


def _convert_result(result):
def _convert_result(result): # Mimic an on-the-wire response from AAD
error = result.get_error()
if error:
return {
Expand All @@ -43,13 +43,15 @@ def _convert_result(result):
error.get_context(), # Available since pymsalruntime 0.0.4
error.get_status(), error.get_error_code(), error.get_tag()),
}
id_token_claims = json.loads(result.get_id_token()) if result.get_id_token() else {}
account = result.get_account()
assert account.get_account_id() == id_token_claims.get("oid"), "Emperical observation" # TBD
return {k: v for k, v in {
"access_token": result.get_access_token(),
"expires_in": result.get_access_token_expiry_time(),
#"scope": result.get_granted_scopes(), # TODO
"id_token_claims": json.loads(result.get_id_token())
if result.get_id_token() else None,
"account": result.get_account(),
"id_token_claims": id_token_claims,
"client_info": account.get_client_info(),
}.items() if v}


Expand Down

0 comments on commit b488f16

Please sign in to comment.