diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 652d436b82144..517bc144cb7ef 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -1,13 +1,13 @@ """Test service helpers.""" import asyncio -from collections.abc import Callable, Iterable +from collections.abc import Callable, Generator, Iterable from copy import deepcopy import dataclasses import io import threading from typing import Any -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock, call as mock_call, patch import pytest from pytest_unordered import unordered @@ -48,6 +48,7 @@ entity_registry as er, service, ) +from homeassistant.helpers.entity import Entity from homeassistant.loader import ( Integration, async_get_integration, @@ -75,17 +76,6 @@ SUPPORT_C = 4 -@pytest.fixture -def mock_handle_entity_call(): - """Mock service platform call.""" - with patch( - "homeassistant.helpers.service._handle_single_entity_call", - new_callable=AsyncMock, - return_value=None, - ) as mock_call: - yield mock_call - - @pytest.fixture def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]: """Return mock entities in an ordered dict.""" @@ -127,6 +117,18 @@ def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]: return entities +@pytest.fixture +def mock_entities_method() -> Generator[AsyncMock]: + """Patch test_method on the base Entity class.""" + mock = AsyncMock() + + async def _stub(self: Entity, **kwargs: Any) -> None: + await mock(self, **kwargs) + + with patch.object(Entity, "test_method", _stub, create=True): + yield mock + + @pytest.fixture def floor_area_mock(hass: HomeAssistant) -> None: """Mock including floor and area info.""" @@ -1686,7 +1688,9 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None: async def test_call_single_entity_uses_parallel_updates( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check that single entity calls go through async_request_call.""" entity = mock_entities["light.kitchen"] @@ -1698,7 +1702,7 @@ async def test_call_single_entity_uses_parallel_updates( service_call = service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", @@ -1710,13 +1714,13 @@ async def test_call_single_entity_uses_parallel_updates( # Give the event loop a chance to progress; the call should be blocked await asyncio.sleep(0) - assert mock_handle_entity_call.await_count == 0 + mock_entities_method.assert_not_called() # Release the semaphore so the call can proceed entity.parallel_updates.release() await task - assert mock_handle_entity_call.await_count == 1 + mock_entities_method.assert_called_once_with(entity) async def test_call_context_user_not_exist(hass: HomeAssistant) -> None: @@ -1738,7 +1742,9 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None: async def test_call_context_target_all( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check we only target allowed entities if targeting all.""" with patch( @@ -1753,7 +1759,7 @@ async def test_call_context_target_all( await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", @@ -1763,12 +1769,13 @@ async def test_call_context_target_all( ), ) - assert len(mock_handle_entity_call.mock_calls) == 1 - assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" + mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"]) async def test_call_context_target_specific( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check targeting specific entities.""" with patch( @@ -1782,7 +1789,7 @@ async def test_call_context_target_specific( await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", @@ -1792,12 +1799,12 @@ async def test_call_context_target_specific( ), ) - assert len(mock_handle_entity_call.mock_calls) == 1 - assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" + mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"]) async def test_call_context_target_specific_no_auth( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], ) -> None: """Check targeting specific entities without auth.""" with ( @@ -1810,7 +1817,7 @@ async def test_call_context_target_specific_no_auth( await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", @@ -1825,32 +1832,35 @@ async def test_call_context_target_specific_no_auth( async def test_call_no_context_target_all( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check we target all if no user context given.""" await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL} ), ) - assert len(mock_handle_entity_call.mock_calls) == 4 - assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( - mock_entities.values() + assert mock_entities_method.call_args_list == unordered( + mock_call(entity) for entity in mock_entities.values() ) async def test_call_no_context_target_specific( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check we can target specified entities.""" await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall( hass, "test_domain", @@ -1859,42 +1869,23 @@ async def test_call_no_context_target_specific( ), ) - assert len(mock_handle_entity_call.mock_calls) == 1 - assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" - - -async def test_call_with_match_all( - hass: HomeAssistant, - mock_handle_entity_call, - mock_entities, - caplog: pytest.LogCaptureFixture, -) -> None: - """Check we only target allowed entities if targeting all.""" - await service.entity_service_call( - hass, - mock_entities, - Mock(), - ServiceCall(hass, "test_domain", "test_service", {"entity_id": "all"}), - ) - - assert len(mock_handle_entity_call.mock_calls) == 4 - assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( - mock_entities.values() - ) + mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"]) async def test_call_with_omit_entity_id( - hass: HomeAssistant, mock_handle_entity_call, mock_entities + hass: HomeAssistant, + mock_entities: dict[str, MockEntity], + mock_entities_method: AsyncMock, ) -> None: """Check service call if we do not pass an entity ID.""" await service.entity_service_call( hass, mock_entities, - Mock(), + "test_method", ServiceCall(hass, "test_domain", "test_service"), ) - assert len(mock_handle_entity_call.mock_calls) == 0 + mock_entities_method.assert_not_called() async def test_register_admin_service(