diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index cb13afefacbe..ea1b086cd8e7 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -35,6 +35,7 @@ class EnvironmentVariables: AZURE_PASSWORD = "AZURE_PASSWORD" USERNAME_PASSWORD_VARS = (AZURE_CLIENT_ID, AZURE_USERNAME, AZURE_PASSWORD) + AZURE_POD_IDENTITY_TOKEN_URL = "AZURE_POD_IDENTITY_TOKEN_URL" IDENTITY_ENDPOINT = "IDENTITY_ENDPOINT" IDENTITY_HEADER = "IDENTITY_HEADER" IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index fef901282e99..291a9387ebd4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import logging +import os from typing import TYPE_CHECKING import six @@ -11,6 +12,7 @@ from azure.core.pipeline.transport import HttpRequest from .. import CredentialUnavailableError +from .._constants import EnvironmentVariables from .._internal.get_token_mixin import GetTokenMixin from .._internal.managed_identity_client import ManagedIdentityClient @@ -34,7 +36,7 @@ def get_request(scope, identity_config): - request = HttpRequest("GET", IMDS_URL) + request = HttpRequest("GET", os.environ.get(EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL, IMDS_URL)) request.format_parameters(dict({"api-version": "2018-02-01", "resource": scope}, **identity_config)) return request @@ -45,7 +47,10 @@ def __init__(self, **kwargs): super(ImdsCredential, self).__init__() self._client = ManagedIdentityClient(get_request, **dict(PIPELINE_SETTINGS, **kwargs)) - self._endpoint_available = None # type: Optional[bool] + if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ: + self._endpoint_available = True # type: Optional[bool] + else: + self._endpoint_available = None self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs def _acquire_token_silently(self, *scopes): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 5417ddfb5aeb..628e9e7dfb76 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -3,11 +3,13 @@ # Licensed under the MIT License. # ------------------------------------ import logging +import os from typing import TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError, HttpResponseError from ... import CredentialUnavailableError +from ..._constants import EnvironmentVariables from .._internal import AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from .._internal.managed_identity_client import AsyncManagedIdentityClient @@ -25,7 +27,10 @@ def __init__(self, **kwargs: "Any") -> None: super().__init__() self._client = AsyncManagedIdentityClient(get_request, **PIPELINE_SETTINGS, **kwargs) - self._endpoint_available = None # type: Optional[bool] + if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ: + self._endpoint_available = True # type: Optional[bool] + else: + self._endpoint_available = None self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs async def close(self) -> None: diff --git a/sdk/identity/azure-identity/tests/test_imds_credential.py b/sdk/identity/azure-identity/tests/test_imds_credential.py index f4b766248cfc..8a771496ec43 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential.py @@ -9,6 +9,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.imds import ImdsCredential, IMDS_URL, PIPELINE_SETTINGS from azure.identity._internal.user_agent import USER_AGENT import pytest @@ -176,6 +177,43 @@ def test_identity_config(): assert token == expected_token +def test_imds_url_override(): + url = "https://localhost/token" + expected_token = "***" + scope = "scope" + now = int(time.time()) + + transport = validating_transport( + requests=[ + Request( + base_url=url, + method="GET", + required_headers={"Metadata": "true", "User-Agent": USER_AGENT}, + required_params={"api-version": "2018-02-01", "resource": scope}, + ), + ], + responses=[ + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 42, + "expires_on": now + 42, + "ext_expires_in": 42, + "not_before": now, + "resource": scope, + "token_type": "Bearer", + } + ), + ], + ) + + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True): + credential = ImdsCredential(transport=transport) + token = credential.get_token(scope) + + assert token.token == expected_token + + @pytest.mark.usefixtures("record_imds_test") class RecordedTests(RecordedTestCase): def test_system_assigned(self): diff --git a/sdk/identity/azure-identity/tests/test_imds_credential_async.py b/sdk/identity/azure-identity/tests/test_imds_credential_async.py index 7ac43285562b..e7ce37dc4669 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential_async.py @@ -9,9 +9,10 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError +from azure.identity._constants import EnvironmentVariables +from azure.identity._credentials.imds import IMDS_URL from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio._credentials.imds import ImdsCredential, PIPELINE_SETTINGS -from azure.identity._credentials.imds import IMDS_URL import pytest from helpers import mock_response, Request @@ -211,6 +212,43 @@ async def test_identity_config(): assert token == expected_token +async def test_imds_url_override(): + url = "https://localhost/token" + expected_token = "***" + scope = "scope" + now = int(time.time()) + + transport = async_validating_transport( + requests=[ + Request( + base_url=url, + method="GET", + required_headers={"Metadata": "true", "User-Agent": USER_AGENT}, + required_params={"api-version": "2018-02-01", "resource": scope}, + ), + ], + responses=[ + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 42, + "expires_on": now + 42, + "ext_expires_in": 42, + "not_before": now, + "resource": scope, + "token_type": "Bearer", + } + ), + ], + ) + + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True): + credential = ImdsCredential(transport=transport) + token = await credential.get_token(scope) + + assert token.token == expected_token + + @pytest.mark.usefixtures("record_imds_test") class RecordedTests(RecordedTestCase): @await_test