Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from azure.core.tracing.decorator import distributed_trace

from ._search_service_client_base import SearchServiceClientBase
from ._generated import SearchServiceClient as _SearchServiceClient
from .._headers_mixin import HeadersMixin
from .._version import SDK_MONIKER
from ._datasources_client import SearchDataSourcesClient
from ._indexes_client import SearchIndexesClient
Expand All @@ -22,7 +22,7 @@
from azure.core.credentials import AzureKeyCredential


class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods
class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
Expand All @@ -44,13 +44,10 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth

def __init__(self, endpoint, credential, **kwargs):
# type: (str, AzureKeyCredential, **Any) -> None

self._endpoint = endpoint # type: str
self._credential = credential # type: AzureKeyCredential
super(SearchServiceClient, self).__init__(endpoint, credential, **kwargs)
self._client = _SearchServiceClient(
endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs
) # type: _SearchServiceClient

self._indexes_client = SearchIndexesClient(endpoint, credential, **kwargs)

self._synonym_maps_client = SearchSynonymMapsClient(
Expand All @@ -65,10 +62,6 @@ def __init__(self, endpoint, credential, **kwargs):

self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs)

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]

def __enter__(self):
# type: () -> SearchServiceClient
self._client.__enter__() # pylint:disable=no-member
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING

from .._headers_mixin import HeadersMixin
from ._datasources_client import SearchDataSourcesClient
from ._indexes_client import SearchIndexesClient
from ._indexers_client import SearchIndexersClient
from ._skillsets_client import SearchSkillsetsClient
from ._synonym_maps_client import SearchSynonymMapsClient
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't seem those imports are necessary here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is pylint passing? ⭕

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed these - but im shocked to see pylint not catching unused-import


if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, List, Optional, Sequence
from azure.core.credentials import AzureKeyCredential


class SearchServiceClientBase(HeadersMixin): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
:type endpoint: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials import AzureKeyCredential

.. admonition:: Example:

.. literalinclude:: ../samples/sample_authentication.py
:start-after: [START create_search_service_with_key]
:end-before: [END create_search_service_with_key]
:language: python
:dedent: 4
:caption: Creating the SearchServiceClient with an API key.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the snippet be on this base class that users will never use directly themselves?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed these

"""

_ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str

def __init__(self, endpoint, credential):
# type: (str, AzureKeyCredential) -> None

try:
if endpoint.lower().startswith('http') and not endpoint.lower().startswith('https'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity I would do it:

if not endpoint.lower().startswith('http'):
    endpoint = "https://" + endpoint
elif not endpoint.lower().startswith('https'):
    raise ValueError("Bearer token authentication is not permitted for non-TLS protected (non-https) URLs.")

raise ValueError("Bearer token authentication is not permitted for non-TLS protected (non-https) URLs.")
if not endpoint.lower().startswith('http'):
endpoint = "https://" + endpoint
except AttributeError:
raise ValueError("Endpoint must be a string.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could put this in a util function that has simple unit tests, then

self._endpoint = normalize_endpoint(endpoint)

below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


self._endpoint = endpoint # type: str
self._credential = credential # type: AzureKeyCredential

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from azure.core.tracing.decorator_async import distributed_trace_async

from .._generated.aio import SearchServiceClient as _SearchServiceClient
from ..._headers_mixin import HeadersMixin
from .._search_service_client_base import SearchServiceClientBase
from ..._version import SDK_MONIKER
from ._datasources_client import SearchDataSourcesClient
from ._indexes_client import SearchIndexesClient
Expand All @@ -22,7 +22,7 @@
from azure.core.credentials import AzureKeyCredential


class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-methods
class SearchServiceClient(SearchServiceClientBase): # pylint: disable=too-many-public-methods
"""A client to interact with an existing Azure search service.

:param endpoint: The URL endpoint of an Azure search service
Expand All @@ -45,8 +45,7 @@ class SearchServiceClient(HeadersMixin): # pylint: disable=too-many-public-meth
def __init__(self, endpoint, credential, **kwargs):
# type: (str, AzureKeyCredential, **Any) -> None

self._endpoint = endpoint # type: str
self._credential = credential # type: AzureKeyCredential
super().__init__(endpoint, credential, **kwargs)
self._client = _SearchServiceClient(
endpoint=endpoint, sdk_moniker=SDK_MONIKER, **kwargs
) # type: _SearchServiceClient
Expand All @@ -65,10 +64,6 @@ def __init__(self, endpoint, credential, **kwargs):

self._indexers_client = SearchIndexersClient(endpoint, credential, **kwargs)

def __repr__(self):
# type: () -> str
return "<SearchServiceClient [endpoint={}]>".format(repr(self._endpoint))[:1024]

async def __aenter__(self):
# type: () -> SearchServiceClient
await self._client.__aenter__() # pylint:disable=no-member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_credential_roll(self):
def test_repr(self):
client = SearchServiceClient("endpoint", CREDENTIAL)
assert repr(client) == "<SearchServiceClient [endpoint={}]>".format(
repr("endpoint")
repr("https://endpoint")
)

@mock.patch(
Expand All @@ -52,3 +52,17 @@ def test_get_service_statistics(self, mock_get_stats):
assert mock_get_stats.called
assert mock_get_stats.call_args[0] == ()
assert mock_get_stats.call_args[1] == {"headers": client._headers}

def test_endpoint_https(self):
credential = AzureKeyCredential(key="old_api_key")
client = SearchServiceClient("endpoint", credential)
assert client._endpoint.startswith('https')

client = SearchServiceClient("https://endpoint", credential)
assert client._endpoint.startswith('https')

with pytest.raises(ValueError):
client = SearchServiceClient("http://endpoint", credential)

with pytest.raises(ValueError):
client = SearchServiceClient(12345, credential)