Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 50 additions & 59 deletions tests/helpers/test_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Test service helpers."""

import asyncio
from collections.abc import Callable, Iterable
from collections.abc import Callable, Generator, Iterable
from copy import deepcopy
import dataclasses
import io
import threading
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, Mock, call as mock_call, patch

import pytest
from pytest_unordered import unordered
Expand Down Expand Up @@ -48,6 +48,7 @@
entity_registry as er,
service,
)
from homeassistant.helpers.entity import Entity
from homeassistant.loader import (
Integration,
async_get_integration,
Expand Down Expand Up @@ -75,17 +76,6 @@
SUPPORT_C = 4


@pytest.fixture
def mock_handle_entity_call():
"""Mock service platform call."""
with patch(
"homeassistant.helpers.service._handle_single_entity_call",
new_callable=AsyncMock,
return_value=None,
) as mock_call:
yield mock_call


@pytest.fixture
def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]:
"""Return mock entities in an ordered dict."""
Expand Down Expand Up @@ -127,6 +117,18 @@ def mock_entities(hass: HomeAssistant) -> dict[str, MockEntity]:
return entities


@pytest.fixture
def mock_entities_method() -> Generator[AsyncMock]:
"""Patch test_method on the base Entity class."""
mock = AsyncMock()

async def _stub(self: Entity, **kwargs: Any) -> None:
await mock(self, **kwargs)

with patch.object(Entity, "test_method", _stub, create=True):
yield mock


@pytest.fixture
def floor_area_mock(hass: HomeAssistant) -> None:
"""Mock including floor and area info."""
Expand Down Expand Up @@ -1686,7 +1688,9 @@ async def test_call_with_sync_attr(hass: HomeAssistant, mock_entities) -> None:


async def test_call_single_entity_uses_parallel_updates(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check that single entity calls go through async_request_call."""
entity = mock_entities["light.kitchen"]
Expand All @@ -1698,7 +1702,7 @@ async def test_call_single_entity_uses_parallel_updates(
service_call = service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass,
"test_domain",
Expand All @@ -1710,13 +1714,13 @@ async def test_call_single_entity_uses_parallel_updates(

# Give the event loop a chance to progress; the call should be blocked
await asyncio.sleep(0)
assert mock_handle_entity_call.await_count == 0
mock_entities_method.assert_not_called()

# Release the semaphore so the call can proceed
entity.parallel_updates.release()
await task

assert mock_handle_entity_call.await_count == 1
mock_entities_method.assert_called_once_with(entity)


async def test_call_context_user_not_exist(hass: HomeAssistant) -> None:
Expand All @@ -1738,7 +1742,9 @@ async def test_call_context_user_not_exist(hass: HomeAssistant) -> None:


async def test_call_context_target_all(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check we only target allowed entities if targeting all."""
with patch(
Expand All @@ -1753,7 +1759,7 @@ async def test_call_context_target_all(
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass,
"test_domain",
Expand All @@ -1763,12 +1769,13 @@ async def test_call_context_target_all(
),
)

assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"])


async def test_call_context_target_specific(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check targeting specific entities."""
with patch(
Expand All @@ -1782,7 +1789,7 @@ async def test_call_context_target_specific(
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass,
"test_domain",
Expand All @@ -1792,12 +1799,12 @@ async def test_call_context_target_specific(
),
)

assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"])


async def test_call_context_target_specific_no_auth(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
) -> None:
"""Check targeting specific entities without auth."""
with (
Expand All @@ -1810,7 +1817,7 @@ async def test_call_context_target_specific_no_auth(
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass,
"test_domain",
Expand All @@ -1825,32 +1832,35 @@ async def test_call_context_target_specific_no_auth(


async def test_call_no_context_target_all(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check we target all if no user context given."""
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass, "test_domain", "test_service", data={"entity_id": ENTITY_MATCH_ALL}
),
)

assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
assert mock_entities_method.call_args_list == unordered(
mock_call(entity) for entity in mock_entities.values()
)


async def test_call_no_context_target_specific(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check we can target specified entities."""
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(
hass,
"test_domain",
Expand All @@ -1859,42 +1869,23 @@ async def test_call_no_context_target_specific(
),
)

assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"


async def test_call_with_match_all(
hass: HomeAssistant,
mock_handle_entity_call,
mock_entities,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Check we only target allowed entities if targeting all."""
await service.entity_service_call(
hass,
mock_entities,
Mock(),
ServiceCall(hass, "test_domain", "test_service", {"entity_id": "all"}),
)

assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
mock_entities_method.assert_called_once_with(mock_entities["light.kitchen"])


async def test_call_with_omit_entity_id(
hass: HomeAssistant, mock_handle_entity_call, mock_entities
hass: HomeAssistant,
mock_entities: dict[str, MockEntity],
mock_entities_method: AsyncMock,
) -> None:
"""Check service call if we do not pass an entity ID."""
await service.entity_service_call(
hass,
mock_entities,
Mock(),
"test_method",
ServiceCall(hass, "test_domain", "test_service"),
)

assert len(mock_handle_entity_call.mock_calls) == 0
mock_entities_method.assert_not_called()


async def test_register_admin_service(
Expand Down