Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 113 additions & 2 deletions homeassistant/helpers/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from __future__ import annotations

from collections.abc import Callable
import dataclasses
import logging
from logging import Logger
from typing import TypeGuard
from typing import Any, TypeGuard

from homeassistant.const import (
ATTR_AREA_ID,
Expand All @@ -14,7 +16,14 @@
ATTR_LABEL_ID,
ENTITY_MATCH_NONE,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError

from . import (
area_registry as ar,
Expand All @@ -25,8 +34,11 @@
group,
label_registry as lr,
)
from .event import async_track_state_change_event
from .typing import ConfigType

_LOGGER = logging.getLogger(__name__)


def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
"""Check if ids can match anything."""
Expand Down Expand Up @@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids(
)

return selected


class TargetStateChangeTracker:
"""Helper class to manage state change tracking for targets."""

def __init__(
self,
hass: HomeAssistant,
selector_data: TargetSelectorData,
action: Callable[[Event[EventStateChangedData]], Any],
) -> None:
"""Initialize the state change tracker."""
self._hass = hass
self._selector_data = selector_data
self._action = action

self._state_change_unsub: CALLBACK_TYPE | None = None
self._registry_unsubs: list[CALLBACK_TYPE] = []

def async_setup(self) -> Callable[[], None]:
"""Set up the state change tracking."""
self._setup_registry_listeners()
self._track_entities_state_change()
Comment thread
abmantis marked this conversation as resolved.
return self._unsubscribe

def _track_entities_state_change(self) -> None:
"""Set up state change tracking for currently selected entities."""
selected = async_extract_referenced_entity_ids(
self._hass, self._selector_data, expand_group=False
)

@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
"""Handle state change events."""
if (
event.data["entity_id"] in selected.referenced
or event.data["entity_id"] in selected.indirectly_referenced
):
self._action(event)

tracked_entities = selected.referenced.union(selected.indirectly_referenced)

_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
self._state_change_unsub = async_track_state_change_event(
self._hass, tracked_entities, state_change_listener
)

def _setup_registry_listeners(self) -> None:
"""Set up listeners for registry changes that require resubscription."""

@callback
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
"""Resubscribe to state change events when registry changes."""
if self._state_change_unsub:
self._state_change_unsub()
self._track_entities_state_change()

# Subscribe to registry updates that can change the entities to track:
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
# - Device registry: device labels changed; device area changed.
# - Area registry: area floor changed.
#
# We don't track other registries (like floor or label registries) because their
# changes don't affect which entities are tracked.
self._registry_unsubs = [
self._hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event
),
self._hass.bus.async_listen(
ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event
),
]

def _unsubscribe(self) -> None:
"""Unsubscribe from all events."""
for registry_unsub in self._registry_unsubs:
registry_unsub()
self._registry_unsubs.clear()
if self._state_change_unsub:
self._state_change_unsub()
self._state_change_unsub = None


def async_track_target_selector_state_change_event(
hass: HomeAssistant,
target_selector_config: ConfigType,
action: Callable[[Event[EventStateChangedData]], Any],
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
selector_data = TargetSelectorData(target_selector_config)
if not selector_data.has_any_selector:
raise HomeAssistantError(
f"Target selector {target_selector_config} does not have any selectors defined"
)
tracker = TargetStateChangeTracker(hass, selector_data, action)
return tracker.async_setup()
194 changes: 190 additions & 4 deletions tests/helpers/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import pytest

# TODO(abmantis): is this import needed?
# To prevent circular import when running just this file
import homeassistant.components # noqa: F401
from homeassistant.components.group import Group
from homeassistant.const import (
ATTR_AREA_ID,
Expand All @@ -17,17 +14,21 @@
STATE_ON,
EntityCategory,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
label_registry as lr,
target,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component

from tests.common import (
MockConfigEntry,
RegistryEntryWithDefaults,
mock_area_registry,
mock_device_registry,
Expand Down Expand Up @@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids(
)
== expected_selected
)


async def test_async_track_target_selector_state_change_event_empty_selector(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_track_target_selector_state_change_event with empty selector."""

@callback
def state_change_callback(event):
"""Handle state change events."""

with pytest.raises(HomeAssistantError) as excinfo:
target.async_track_target_selector_state_change_event(
hass, {}, state_change_callback
)
assert str(excinfo.value) == (
"Target selector {} does not have any selectors defined"
)


async def test_async_track_target_selector_state_change_event(
hass: HomeAssistant,
) -> None:
"""Test async_track_target_selector_state_change_event with multiple targets."""
events: list[Event[EventStateChangedData]] = []

@callback
def state_change_callback(event: Event[EventStateChangedData]):
"""Handle state change events."""
events.append(event)

last_state = STATE_OFF

async def set_states_and_check_events(
entities_to_set_state: list[str], entities_to_assert_change: list[str]
) -> None:
"""Toggle the state entities and check for events."""
nonlocal last_state
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
for entity_id in entities_to_set_state:
hass.states.async_set(entity_id, last_state)
await hass.async_block_till_done()

assert len(events) == len(entities_to_assert_change)
entities_seen = set()
for event in events:
entities_seen.add(event.data["entity_id"])
assert event.data["new_state"].state == last_state
assert entities_seen == set(entities_to_assert_change)
events.clear()

config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)

device_reg = dr.async_get(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device_1")},
)

untargeted_device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "area_device")},
)

entity_reg = er.async_get(hass)
device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light",
device_id=device_entry.id,
).entity_id

untargeted_device_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="area_device_light",
device_id=untargeted_device_entry.id,
).entity_id

untargeted_entity = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="untargeted_light",
).entity_id

targeted_entity = "light.test_light"

targeted_entities = [targeted_entity, device_entity]
await set_states_and_check_events(targeted_entities, [])

label = lr.async_get(hass).async_create("Test Label").name
area = ar.async_get(hass).async_create("Test Area").id
floor = fr.async_get(hass).async_create("Test Floor").floor_id

selector_config = {
ATTR_ENTITY_ID: targeted_entity,
ATTR_DEVICE_ID: device_entry.id,
ATTR_AREA_ID: area,
ATTR_FLOOR_ID: floor,
ATTR_LABEL_ID: label,
}
unsub = target.async_track_target_selector_state_change_event(
hass, selector_config, state_change_callback
)

# Test directly targeted entity and device
await set_states_and_check_events(targeted_entities, targeted_entities)

# Add new entity to the targeted device -> should trigger on state change
device_entity_2 = entity_reg.async_get_or_create(
domain="light",
platform="test",
unique_id="device_light_2",
device_id=device_entry.id,
).entity_id

targeted_entities = [targeted_entity, device_entity, device_entity_2]
await set_states_and_check_events(targeted_entities, targeted_entities)

# Test untargeted entity -> should not trigger
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)

# Add label to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, labels={label})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)

# Remove label from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, labels={})
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)

# Add area to untargeted entity -> should trigger now
entity_reg.async_update_entity(untargeted_entity, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
)

# Remove area from untargeted entity -> should not trigger anymore
entity_reg.async_update_entity(untargeted_entity, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)

# Add area to untargeted device -> should trigger on state change
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity],
[*targeted_entities, untargeted_device_entity],
)

# Remove area from untargeted device -> should not trigger anymore
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_device_entity], targeted_entities
)

# Set the untargeted area on the untargeted entity -> should not trigger
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)

# Set targeted floor on the untargeted area -> should trigger now
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity],
[*targeted_entities, untargeted_entity],
)

# Remove untargeted area from targeted floor -> should not trigger anymore
ar.async_get(hass).async_update(untracked_area, floor_id=None)
await set_states_and_check_events(
[*targeted_entities, untargeted_entity], targeted_entities
)

# After unsubscribing, changes should not trigger
unsub()
await set_states_and_check_events(targeted_entities, [])