diff --git a/msal/application.py b/msal/application.py index 42c12f96..84493340 100644 --- a/msal/application.py +++ b/msal/application.py @@ -72,7 +72,8 @@ def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, token_cache=None, - verify=True, proxies=None, timeout=None): + verify=True, proxies=None, timeout=None, + client_claims=None): """Create an instance of application. :param client_id: Your app has a client_id after you register it on AAD. @@ -91,6 +92,22 @@ def __init__( public_certificate (optional) is public key certificate which is sent through 'x5c' JWT header only for subject name and issuer authentication to support cert auto rolls + + :param dict client_claims: + It is a dictionary of extra claims that would be signed by + by this :class:`ConfidentialClientApplication` 's private key. + For example, you can use {"client_ip": "x.x.x.x"}. + You may also override any of the following default claims: + + { + "aud": the_token_endpoint, + "iss": self.client_id, + "sub": same_as_issuer, + "exp": now + 10_min, + "iat": now, + "jti": a_random_uuid + } + :param str authority: A URL that identifies a token authority. It should be of the format https://login.microsoftonline.com/your_tenant @@ -115,6 +132,7 @@ def __init__( """ self.client_id = client_id self.client_credential = client_credential + self.client_claims = client_claims self.verify = verify self.proxies = proxies self.timeout = timeout @@ -140,7 +158,8 @@ def _build_client(self, client_credential, authority): client_credential["private_key"], algorithm="RS256", sha1_thumbprint=client_credential.get("thumbprint"), headers=headers) client_assertion = signer.sign_assertion( - audience=authority.token_endpoint, issuer=self.client_id) + audience=authority.token_endpoint, issuer=self.client_id, + additional_claims=self.client_claims or {}) client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT else: default_body['client_secret'] = client_credential diff --git a/tests/test_assertion.py b/tests/test_assertion.py new file mode 100644 index 00000000..a4921138 --- /dev/null +++ b/tests/test_assertion.py @@ -0,0 +1,15 @@ +import json + +from msal.oauth2cli import JwtSigner +from msal.oauth2cli.oidc import base64decode + +from tests import unittest + + +class AssertionTestCase(unittest.TestCase): + def test_extra_claims(self): + assertion = JwtSigner(key=None, algorithm="none").sign_assertion( + "audience", "issuer", additional_claims={"client_ip": "1.2.3.4"}) + payload = json.loads(base64decode(assertion.split(b'.')[1].decode('utf-8'))) + self.assertEqual("1.2.3.4", payload.get("client_ip")) +