diff --git a/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client.py b/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client.py index 9c5205d3dd6b..19fa5c3bbdec 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client.py @@ -7,8 +7,8 @@ from azure.core.tracing.decorator import distributed_trace +from ._search_service_client_base import SearchServiceClientBase from ._generated import SearchServiceClient as _SearchServiceClient -from .._headers_mixin import HeadersMixin from .._version import SDK_MONIKER from ._datasources_client import SearchDataSourcesClient from ._indexes_client import SearchIndexesClient @@ -22,7 +22,7 @@ from azure.core.credentials import AzureKeyCredential -class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods +class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods """A client to interact with an existing Azure search service. :param endpoint: The URL endpoint of an Azure search service @@ -44,13 +44,10 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth def __init__(self, endpoint, credential, **kwargs): # type: (str, AzureKeyCredential, **Any) -> None - - self._endpoint = endpoint # type: str - self._credential = credential # type: AzureKeyCredential + super(SearchServiceClient, self).__init__(endpoint, credential, **kwargs) self._client = _SearchServiceClient( endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs ) # type: _SearchServiceClient - self._indexes_client = SearchIndexesClient(endpoint, credential, **kwargs) self._synonym_maps_client = SearchSynonymMapsClient( @@ -65,10 +62,6 @@ def __init__(self, endpoint, credential, **kwargs): self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs) - def __repr__(self): - # type: () -> str - return "".format(repr(self._endpoint))[:1024] - def __enter__(self): # type: () -> SearchServiceClient self._client.__enter__() # pylint:disable=no-member diff --git a/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client_base.py b/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client_base.py new file mode 100644 index 000000000000..cccac76a15d4 --- /dev/null +++ b/sdk/search/azure-search-documents/azure/search/documents/_service/_search_service_client_base.py @@ -0,0 +1,36 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import TYPE_CHECKING + +from .._headers_mixin import HeadersMixin +from ._utils import _normalize_endpoint + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Dict, List, Optional, Sequence + from azure.core.credentials import AzureKeyCredential + + +class SearchServiceClientBase(HeadersMixin): # pylint: disable=too-many-public-methods + """A client to interact with an existing Azure search service. + + :param endpoint: The URL endpoint of an Azure search service + :type endpoint: str + :param credential: A credential to authorize search client requests + :type credential: ~azure.core.credentials import AzureKeyCredential + """ + + _ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str + + def __init__(self, endpoint, credential): + # type: (str, AzureKeyCredential) -> None + + self._endpoint = _normalize_endpoint(endpoint) # type: str + self._credential = credential # type: AzureKeyCredential + + def __repr__(self): + # type: () -> str + return "".format(repr(self._endpoint))[:1024] diff --git a/sdk/search/azure-search-documents/azure/search/documents/_service/_utils.py b/sdk/search/azure-search-documents/azure/search/documents/_service/_utils.py index 6ad76de52769..0b0f6e6d2605 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_service/_utils.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_service/_utils.py @@ -187,3 +187,13 @@ def get_access_conditions(model, match_condition=MatchConditions.Unconditionally return (error_map, AccessCondition(if_match=if_match, if_none_match=if_none_match)) except AttributeError: raise ValueError("Unable to get e_tag from the model") + +def _normalize_endpoint(endpoint): + try: + if not endpoint.lower().startswith('http'): + endpoint = "https://" + endpoint + elif not endpoint.lower().startswith('https'): + raise ValueError("Bearer token authentication is not permitted for non-TLS protected (non-https) URLs.") + return endpoint + except AttributeError: + raise ValueError("Endpoint must be a string.") diff --git a/sdk/search/azure-search-documents/azure/search/documents/_service/aio/_search_service_client_async.py b/sdk/search/azure-search-documents/azure/search/documents/_service/aio/_search_service_client_async.py index 6a5969b39443..82db403cbfca 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_service/aio/_search_service_client_async.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_service/aio/_search_service_client_async.py @@ -8,7 +8,7 @@ from azure.core.tracing.decorator_async import distributed_trace_async from .._generated.aio import SearchServiceClient as _SearchServiceClient -from ..._headers_mixin import HeadersMixin +from .._search_service_client_base import SearchServiceClientBase from ..._version import SDK_MONIKER from ._datasources_client import SearchDataSourcesClient from ._indexes_client import SearchIndexesClient @@ -22,7 +22,7 @@ from azure.core.credentials import AzureKeyCredential -class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods +class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods """A client to interact with an existing Azure search service. :param endpoint: The URL endpoint of an Azure search service @@ -45,8 +45,7 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth def __init__(self, endpoint, credential, **kwargs): # type: (str, AzureKeyCredential, **Any) -> None - self._endpoint = endpoint # type: str - self._credential = credential # type: AzureKeyCredential + super().__init__(endpoint, credential, **kwargs) self._client = _SearchServiceClient( endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs ) # type: _SearchServiceClient @@ -65,10 +64,6 @@ def __init__(self, endpoint, credential, **kwargs): self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs) - def __repr__(self): - # type: () -> str - return "".format(repr(self._endpoint))[:1024] - async def __aenter__(self): # type: () -> SearchServiceClient await self._client.__aenter__() # pylint:disable=no-member diff --git a/sdk/search/azure-search-documents/tests/test_search_service_client.py b/sdk/search/azure-search-documents/tests/test_search_service_client.py index 094e6ca00934..167d46936cd9 100644 --- a/sdk/search/azure-search-documents/tests/test_search_service_client.py +++ b/sdk/search/azure-search-documents/tests/test_search_service_client.py @@ -40,7 +40,7 @@ def test_credential_roll(self): def test_repr(self): client = SearchServiceClient("endpoint", CREDENTIAL) assert repr(client) == "".format( - repr("endpoint") + repr("https://endpoint") ) @mock.patch( @@ -52,3 +52,17 @@ def test_get_service_statistics(self, mock_get_stats): assert mock_get_stats.called assert mock_get_stats.call_args[0] == () assert mock_get_stats.call_args[1] == {"headers": client._headers} + + def test_endpoint_https(self): + credential = AzureKeyCredential(key="old_api_key") + client = SearchServiceClient("endpoint", credential) + assert client._endpoint.startswith('https') + + client = SearchServiceClient("https://endpoint", credential) + assert client._endpoint.startswith('https') + + with pytest.raises(ValueError): + client = SearchServiceClient("http://endpoint", credential) + + with pytest.raises(ValueError): + client = SearchServiceClient(12345, credential)