Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non

# Skip entities that don't have the required feature.
if required_features is not None and not any(
entity.supported_features & feature_set for feature_set in required_features
entity.supported_features & feature_set == feature_set
for feature_set in required_features
):
continue

Expand Down
77 changes: 69 additions & 8 deletions tests/helpers/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
mock_service,
)

SUPPORT_A = 1
SUPPORT_B = 2
SUPPORT_C = 4


@pytest.fixture
def mock_handle_entity_call():
Expand All @@ -52,17 +56,31 @@ def mock_entities(hass):
entity_id="light.kitchen",
available=True,
should_poll=False,
supported_features=1,
supported_features=SUPPORT_A,
)
living_room = MockEntity(
entity_id="light.living_room",
available=True,
should_poll=False,
supported_features=0,
supported_features=SUPPORT_B,
)
bedroom = MockEntity(
entity_id="light.bedroom",
available=True,
should_poll=False,
supported_features=(SUPPORT_A | SUPPORT_B),
)
bathroom = MockEntity(
entity_id="light.bathroom",
available=True,
should_poll=False,
supported_features=(SUPPORT_B | SUPPORT_C),
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
entities[living_room.entity_id] = living_room
entities[bedroom.entity_id] = bedroom
entities[bathroom.entity_id] = bathroom
return entities


Expand Down Expand Up @@ -307,18 +325,61 @@ async def test_async_get_all_descriptions(hass):


async def test_call_with_required_features(hass, mock_entities):
"""Test service calls invoked only if entity has required feautres."""
"""Test service calls invoked only if entity has required features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[1],
required_features=[SUPPORT_A],
)
assert len(mock_entities) == 2
# Called once because only one of the entities had the required features

assert test_service_mock.call_count == 2
expected = [
mock_entities["light.kitchen"],
mock_entities["light.bedroom"],
]
actual = [call[0][0] for call in test_service_mock.call_args_list]
assert all(entity in actual for entity in expected)


async def test_call_with_both_required_features(hass, mock_entities):
"""Test service calls invoked only if entity has both features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B],
)

assert test_service_mock.call_count == 1
Comment thread
MartinHjelmare marked this conversation as resolved.
Outdated
assert [call[0][0] for call in test_service_mock.call_args_list] == [
mock_entities["light.bedroom"]
]


async def test_call_with_one_of_required_features(hass, mock_entities):
"""Test service calls invoked with one entity having the required features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C],
)

assert test_service_mock.call_count == 3
expected = [
mock_entities["light.kitchen"],
mock_entities["light.bedroom"],
mock_entities["light.bathroom"],
]
actual = [call[0][0] for call in test_service_mock.call_args_list]
assert all(entity in actual for entity in expected)


async def test_call_with_sync_func(hass, mock_entities):
Expand Down Expand Up @@ -458,7 +519,7 @@ async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_en
),
)

assert len(mock_handle_entity_call.mock_calls) == 2
assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
Expand Down Expand Up @@ -494,7 +555,7 @@ async def test_call_with_match_all(
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
)

assert len(mock_handle_entity_call.mock_calls) == 2
assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
Expand Down