diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 7b83f97da478..4e2deb7729e4 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -14,6 +14,11 @@ `CertificateCredential(tenant_id, client_id, certificate_bytes=cert_bytes)` ([#14055](https://github.com/Azure/azure-sdk-for-python/issues/14055)) +### Fixed +- `ManagedIdentityCredential` correctly parses responses from the current + (preview) version of Azure ML managed identity + ([#15361](https://github.com/Azure/azure-sdk-for-python/issues/15361)) + ## 1.5.0 (2020-11-11) ### Breaking Changes - Renamed optional `CertificateCredential` keyword argument `send_certificate` diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index 2887311fd58b..bc01868c9697 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -86,10 +86,18 @@ def _parse_app_service_expires_on(content): :raises ValueError: ``expires_on`` didn't match an expected format """ + + # Azure ML sets the same environment variables as App Service but returns expires_on as an integer. + # That means we could have an Azure ML response here, so let's first try to parse expires_on as an int. + try: + content["expires_on"] = int(content["expires_on"]) + return + except ValueError: + pass + import calendar import time - # parse the string minus the timezone offset expires_on = content["expires_on"] if expires_on.endswith(" +00:00"): date_string = expires_on[: -len(" +00:00")] diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index 0d82d35c19a1..a4658f9205e1 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -58,6 +58,58 @@ def test_cloud_shell(): assert token == expected_token +def test_azure_ml(): + """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" + + expected_token = AccessToken("****", int(time.time()) + 3600) + url = "http://localhost:42/token" + secret = "expected-secret" + scope = "scope" + client_id = "client" + + transport = validating_transport( + requests=[ + Request( + url, + method="GET", + required_headers={"secret": secret, "User-Agent": USER_AGENT}, + required_params={"api-version": "2017-09-01", "resource": scope}, + ), + Request( + url, + method="GET", + required_headers={"secret": secret, "User-Agent": USER_AGENT}, + required_params={"api-version": "2017-09-01", "resource": scope, "clientid": client_id}, + ), + ], + responses=[ + mock_response( + json_payload={ + "access_token": expected_token.token, + "expires_in": 3600, + "expires_on": expected_token.expires_on, + "resource": scope, + "token_type": "Bearer", + } + ) + ] + * 2, + ) + + with mock.patch.dict( + MANAGED_IDENTITY_ENVIRON, + {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, + clear=True, + ): + token = ManagedIdentityCredential(transport=transport).get_token(scope) + assert token.token == expected_token.token + assert token.expires_on == expected_token.expires_on + + token = ManagedIdentityCredential(transport=transport, client_id=client_id).get_token(scope) + assert token.token == expected_token.token + assert token.expires_on == expected_token.expires_on + + def test_cloud_shell_user_assigned_identity(): """Cloud Shell environment: only MSI_ENDPOINT set""" @@ -172,7 +224,9 @@ def test_prefers_app_service_2017_09_01(): assert token.expires_on == expires_on -@pytest.mark.skip("2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back.") +@pytest.mark.skip( + "2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back." +) def test_prefers_app_service_2019_08_01(): """When the environment is configured for both App Service versions, the credential should prefer the most recent""" @@ -214,7 +268,9 @@ def test_prefers_app_service_2019_08_01(): assert token.expires_on == expires_on -@pytest.mark.skip("2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back.") +@pytest.mark.skip( + "2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back." +) def test_app_service_2019_08_01(): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" @@ -371,6 +427,7 @@ def test_app_service_user_assigned_identity(): assert token.token == expected_token assert token.expires_on == expires_on + def test_imds(): access_token = "****" expires_on = 42 @@ -421,11 +478,7 @@ def send(request, **_): if request.data: assert "client_id" not in request.body # Cloud Shell return mock_response( - json_payload=( - build_aad_response( - access_token=expected_access_token, expires_on="42", resource=scope - ) - ) + json_payload=(build_aad_response(access_token=expected_access_token, expires_on="42", resource=scope)) ) # IMDS @@ -596,11 +649,8 @@ def test_azure_arc(tmpdir): ) with mock.patch( - "os.environ", - { - EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, - EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint, - }, + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): token = ManagedIdentityCredential(transport=transport).get_token(scope) assert token.token == access_token @@ -610,11 +660,11 @@ def test_azure_arc(tmpdir): def test_azure_arc_client_id(): """Azure Arc doesn't support user-assigned managed identity""" with mock.patch( - "os.environ", - { - EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token", - EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42", - } + "os.environ", + { + EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token", + EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42", + }, ): credential = ManagedIdentityCredential(client_id="some-guid") diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 90ceeaee3ead..b7227f4bb06b 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -57,6 +57,61 @@ async def test_cloud_shell(): assert token == expected_token +@pytest.mark.asyncio +async def test_azure_ml(): + """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" + + expected_token = AccessToken("****", int(time.time()) + 3600) + url = "http://localhost:42/token" + secret = "expected-secret" + scope = "scope" + client_id = "client" + + transport = async_validating_transport( + requests=[ + Request( + url, + method="GET", + required_headers={"secret": secret, "User-Agent": USER_AGENT}, + required_params={"api-version": "2017-09-01", "resource": scope}, + ), + Request( + url, + method="GET", + required_headers={"secret": secret, "User-Agent": USER_AGENT}, + required_params={"api-version": "2017-09-01", "resource": scope, "clientid": client_id}, + ), + ], + responses=[ + mock_response( + json_payload={ + "access_token": expected_token.token, + "expires_in": 3600, + "expires_on": expected_token.expires_on, + "resource": scope, + "token_type": "Bearer", + } + ) + ] + * 2, + ) + + with mock.patch.dict( + MANAGED_IDENTITY_ENVIRON, + {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, + clear=True, + ): + credential = ManagedIdentityCredential(transport=transport) + token = await credential.get_token(scope) + assert token.token == expected_token.token + assert token.expires_on == expected_token.expires_on + + credential = ManagedIdentityCredential(transport=transport, client_id=client_id) + token = await credential.get_token(scope) + assert token.token == expected_token.token + assert token.expires_on == expected_token.expires_on + + @pytest.mark.asyncio async def test_cloud_shell_user_assigned_identity(): """Cloud Shell environment: only MSI_ENDPOINT set""" @@ -176,7 +231,9 @@ async def test_prefers_app_service_2017_09_01(): assert token.expires_on == expires_on -@pytest.mark.skip("2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back.") +@pytest.mark.skip( + "2019-08-01 support was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. This test should be enabled when that support is added back." +) @pytest.mark.asyncio async def test_app_service_2019_08_01(): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" @@ -299,11 +356,7 @@ async def test_app_service_user_assigned_identity(): base_url=endpoint, method="GET", required_headers={"secret": secret, "User-Agent": USER_AGENT}, - required_params={ - "api-version": "2017-09-01", - "resource": scope, - param_name: param_value, - }, + required_params={"api-version": "2017-09-01", "resource": scope, param_name: param_value}, ), ], responses=[ @@ -350,11 +403,7 @@ async def send(request, **_): if request.data: assert "client_id" not in request.body # Cloud Shell return mock_response( - json_payload=( - build_aad_response( - access_token=expected_access_token, expires_on="42", resource=scope - ) - ) + json_payload=(build_aad_response(access_token=expected_access_token, expires_on="42", resource=scope)) ) with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {}, clear=True): @@ -567,11 +616,8 @@ async def test_azure_arc(tmpdir): ) with mock.patch( - "os.environ", - { - EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, - EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint, - }, + "os.environ", + {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): token = await ManagedIdentityCredential(transport=transport).get_token(scope) assert token.token == access_token @@ -582,11 +628,11 @@ async def test_azure_arc(tmpdir): async def test_azure_arc_client_id(): """Azure Arc doesn't support user-assigned managed identity""" with mock.patch( - "os.environ", - { - EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token", - EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42", - } + "os.environ", + { + EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token", + EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42", + }, ): credential = ManagedIdentityCredential(client_id="some-guid")