From 89a9ab699df74b1e2ef27856b32a787fa5e7363e Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 19:17:04 +0100 Subject: [PATCH 01/19] Add method to track entity state changes from target selectors --- homeassistant/helpers/trigger.py | 308 ++++++++++++++++++++++++++++++- tests/helpers/test_trigger.py | 240 +++++++++++++++++++++++- 2 files changed, 544 insertions(+), 4 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 66d1560ac70a2a..5f8f45834cd9e0 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -6,24 +6,33 @@ import asyncio from collections import defaultdict from collections.abc import Callable, Coroutine, Iterable +import dataclasses from dataclasses import dataclass, field import functools import logging -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, TypeGuard, cast import voluptuous as vol from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + ATTR_FLOOR_ID, + ATTR_LABEL_ID, CONF_ALIAS, CONF_ENABLED, CONF_ID, CONF_PLATFORM, CONF_VARIABLES, + ENTITY_MATCH_NONE, ) from homeassistant.core import ( CALLBACK_TYPE, Context, + Event, HassJob, + HassJobType, HomeAssistant, callback, is_callback, @@ -40,7 +49,16 @@ from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml.loader import JSON_TYPE -from . import config_validation as cv +from . import ( + area_registry, + config_validation as cv, + device_registry, + entity_registry, + floor_registry, + label_registry, +) +from .event import EventStateChangedData, async_track_state_change_event +from .group import expand_entity_ids from .integration_platform import async_process_integration_platforms from .template import Template from .typing import ConfigType, TemplateVarsType @@ -617,3 +635,289 @@ async def async_get_all_descriptions( hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache return new_descriptions_cache + + +def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: + """Check if ids can match anything.""" + return ids not in (None, ENTITY_MATCH_NONE) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +class TargetSelectorData: + """Class to hold data of target selector.""" + + __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids") + + def __init__(self, config: ConfigType) -> None: + """Extract ids from the config.""" + entity_ids: str | list | None = config.get(ATTR_ENTITY_ID) + device_ids: str | list | None = config.get(ATTR_DEVICE_ID) + area_ids: str | list | None = config.get(ATTR_AREA_ID) + floor_ids: str | list | None = config.get(ATTR_FLOOR_ID) + label_ids: str | list | None = config.get(ATTR_LABEL_ID) + + self.entity_ids = ( + set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() + ) + self.device_ids = ( + set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() + ) + self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() + self.floor_ids = ( + set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() + ) + self.label_ids = ( + set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() + ) + + @property + def has_any_selector(self) -> bool: + """Determine if any selectors are present.""" + return bool( + self.entity_ids + or self.device_ids + or self.area_ids + or self.floor_ids + or self.label_ids + ) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +@dataclasses.dataclass(slots=True) +class SelectedEntities: + """Class to hold the selected entities.""" + + # Entities that were explicitly mentioned. + referenced: set[str] = dataclasses.field(default_factory=set) + + # Entities that were referenced via device/area/floor/label ID. + # Should not trigger a warning when they don't exist. + indirectly_referenced: set[str] = dataclasses.field(default_factory=set) + + # Referenced items that could not be found. + missing_devices: set[str] = dataclasses.field(default_factory=set) + missing_areas: set[str] = dataclasses.field(default_factory=set) + missing_floors: set[str] = dataclasses.field(default_factory=set) + missing_labels: set[str] = dataclasses.field(default_factory=set) + + referenced_devices: set[str] = dataclasses.field(default_factory=set) + referenced_areas: set[str] = dataclasses.field(default_factory=set) + + def log_missing(self, missing_entities: set[str]) -> None: + """Log about missing items.""" + parts = [] + for label, items in ( + ("floors", self.missing_floors), + ("areas", self.missing_areas), + ("devices", self.missing_devices), + ("entities", missing_entities), + ("labels", self.missing_labels), + ): + if items: + parts.append(f"{label} {', '.join(sorted(items))}") + + if not parts: + return + + _LOGGER.warning( + "Referenced %s are missing or not currently available", + ", ".join(parts), + ) + + +# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places +def async_extract_referenced_entity_ids( + hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True +) -> SelectedEntities: + """Extract referenced entity IDs from a target selector.""" + selected = SelectedEntities() + + if not selector_data.has_any_selector: + return selected + + entity_ids: set[str] | list[str] = selector_data.entity_ids + if expand_group: + entity_ids = expand_entity_ids(hass, entity_ids) + + selected.referenced.update(entity_ids) + + if ( + not selector_data.device_ids + and not selector_data.area_ids + and not selector_data.floor_ids + and not selector_data.label_ids + ): + return selected + + entities = entity_registry.async_get(hass).entities + dev_reg = device_registry.async_get(hass) + area_reg = area_registry.async_get(hass) + + if selector_data.floor_ids: + floor_reg = floor_registry.async_get(hass) + for floor_id in selector_data.floor_ids: + if floor_id not in floor_reg.floors: + selected.missing_floors.add(floor_id) + + for area_id in selector_data.area_ids: + if area_id not in area_reg.areas: + selected.missing_areas.add(area_id) + + for device_id in selector_data.device_ids: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) + + if selector_data.label_ids: + label_reg = label_registry.async_get(hass) + for label_id in selector_data.label_ids: + if label_id not in label_reg.labels: + selected.missing_labels.add(label_id) + + for entity_entry in entities.get_entries_for_label(label_id): + if ( + entity_entry.entity_category is None + and entity_entry.hidden_by is None + ): + selected.indirectly_referenced.add(entity_entry.entity_id) + + for device_entry in dev_reg.devices.get_devices_for_label(label_id): + selected.referenced_devices.add(device_entry.id) + + for area_entry in area_reg.areas.get_areas_for_label(label_id): + selected.referenced_areas.add(area_entry.id) + + # Find areas for targeted floors + if selector_data.floor_ids: + selected.referenced_areas.update( + area_entry.id + for floor_id in selector_data.floor_ids + for area_entry in area_reg.areas.get_areas_for_floor(floor_id) + ) + + selected.referenced_areas.update(selector_data.area_ids) + selected.referenced_devices.update(selector_data.device_ids) + + if not selected.referenced_areas and not selected.referenced_devices: + return selected + + # Add indirectly referenced by device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in selected.referenced_devices + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if (entry.entity_category is None and entry.hidden_by is None) + ) + + # Find devices for targeted areas + referenced_devices_by_area: set[str] = set() + if selected.referenced_areas: + for area_id in selected.referenced_areas: + referenced_devices_by_area.update( + device_entry.id + for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) + ) + selected.referenced_devices.update(referenced_devices_by_area) + + # Add indirectly referenced by area + selected.indirectly_referenced.update( + entry.entity_id + for area_id in selected.referenced_areas + # The entity's area matches a targeted area + for entry in entities.get_entries_for_area_id(area_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if entry.entity_category is None and entry.hidden_by is None + ) + # Add indirectly referenced by area through device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in referenced_devices_by_area + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if ( + entry.entity_category is None + and entry.hidden_by is None + and ( + # The entity's device matches a device referenced + # by an area and the entity + # has no explicitly set area + not entry.area_id + ) + ) + ) + + return selected + + +def async_track_target_selector_state_change_event( + hass: HomeAssistant, + target_selector_config: ConfigType, + action: Callable[[Event[EventStateChangedData]], Any], + job_type: HassJobType | None = None, +) -> CALLBACK_TYPE: + """Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector.""" + selector_data = TargetSelectorData(target_selector_config) + if not selector_data.has_any_selector: + _LOGGER.warning( + "Target selector %s does not have any selectors defined", + target_selector_config, + ) + return lambda: None + + def track_entities_state_change() -> CALLBACK_TYPE: + selected = async_extract_referenced_entity_ids( + hass, selector_data, expand_group=False + ) + + @callback + def state_change_listener(event: Event[EventStateChangedData]) -> None: + """Handle state change events.""" + if ( + event.data["entity_id"] in selected.referenced + or event.data["entity_id"] in selected.indirectly_referenced + ): + action(event) + + tracked_entities = selected.referenced.union(selected.indirectly_referenced) + + _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) + return async_track_state_change_event( + hass, tracked_entities, state_change_listener, job_type=job_type + ) + + unsub_state_change = track_entities_state_change() + + def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: + # TODO(abmantis): Check if there is a better way to do this + nonlocal unsub_state_change + unsub_state_change() + unsub_state_change = track_entities_state_change() + + unsub_registry_updates = [ + hass.bus.async_listen( + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + resubscribe_state_change_event, + # TODO(abmantis): filter for entities that match the target selector? + # event_filter=self._filter_entity_registry_changes, + ), + hass.bus.async_listen( + device_registry.EVENT_DEVICE_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + hass.bus.async_listen( + area_registry.EVENT_AREA_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + ] + + def unsub() -> None: + """Unsubscribe from state change and registry update events.""" + for registry_unsub in unsub_registry_updates: + registry_unsub() + unsub_registry_updates.clear() + unsub_state_change() + + return unsub diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 27cde92d14ff24..1f3219c538f7c3 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -10,6 +10,15 @@ from homeassistant.components.sun import DOMAIN as DOMAIN_SUN from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH from homeassistant.components.tag import DOMAIN as DOMAIN_TAG +from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + ATTR_FLOOR_ID, + ATTR_LABEL_ID, + STATE_OFF, + STATE_ON, +) from homeassistant.core import ( CALLBACK_TYPE, Context, @@ -18,7 +27,14 @@ callback, ) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import trigger +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + floor_registry as fr, + label_registry as lr, + trigger, +) from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, @@ -27,6 +43,7 @@ TriggerInfo, _async_get_trigger_platform, async_initialize_triggers, + async_track_target_selector_state_change_event, async_validate_trigger_config, ) from homeassistant.helpers.typing import ConfigType @@ -34,7 +51,13 @@ from homeassistant.setup import async_setup_component from homeassistant.util.yaml.loader import parse_yaml -from tests.common import MockModule, MockPlatform, mock_integration, mock_platform +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + mock_integration, + mock_platform, +) async def test_bad_trigger_platform(hass: HomeAssistant) -> None: @@ -738,3 +761,216 @@ async def test_invalid_trigger_platform( await async_setup_component(hass, "test", {}) assert "Integration test does not provide trigger support, skipping" in caplog.text + + +async def test_async_track_target_selector_state_change_event_empty_selector( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test async_track_target_selector_state_change_event with empty selector.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + unsub = async_track_target_selector_state_change_event( + hass, {}, state_change_callback + ) + + assert "Target selector {} does not have any selectors defined" in caplog.text + + # Test that no state changes are tracked + hass.states.async_set("light.test", "on") + await hass.async_block_till_done() + + assert len(calls) == 0 + + unsub() + + +async def test_async_track_target_selector_state_change_event( + hass: HomeAssistant, +) -> None: + """Test async_track_target_selector_state_change_event with multiple targets.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + async def set_state(entity_id, state): + """Set the state of an entity.""" + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() + + def assert_entity_calls_and_reset(entity_id: str) -> None: + assert len(calls) == 1 + assert calls[0].data["entity_id"] == entity_id + calls.clear() + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_reg = dr.async_get(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "device_1")}, + ) + + untargeted_device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "area_device")}, + ) + + entity_reg = er.async_get(hass) + device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light", + device_id=device_entry.id, + ).entity_id + + untargeted_device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="area_device_light", + device_id=untargeted_device_entry.id, + ).entity_id + + untargeted_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="untargeted_light", + ).entity_id + + targeted_entity = "light.test_light" + + for entity_id in (targeted_entity, device_entity, untargeted_entity): + hass.states.async_set(entity_id, STATE_OFF) + await hass.async_block_till_done() + + label = lr.async_get(hass).async_create("Test Label").name + area = ar.async_get(hass).async_create("Test Area").id + floor = fr.async_get(hass).async_create("Test Floor").floor_id + + selector_config = { + ATTR_ENTITY_ID: targeted_entity, + ATTR_DEVICE_ID: device_entry.id, + ATTR_AREA_ID: area, + ATTR_FLOOR_ID: floor, + ATTR_LABEL_ID: label, + } + unsub = async_track_target_selector_state_change_event( + hass, selector_config, state_change_callback + ) + + # Test directly targeted entity and device + await set_state(targeted_entity, STATE_ON) + await set_state(device_entity, STATE_ON) + + assert len(calls) == 2 + assert calls[0].data["entity_id"] == targeted_entity + assert calls[0].data["old_state"].state == STATE_OFF + assert calls[0].data["new_state"].state == STATE_ON + assert calls[1].data["entity_id"] == device_entity + assert calls[1].data["old_state"].state == STATE_OFF + assert calls[1].data["new_state"].state == STATE_ON + calls.clear() + + # Add new entity to the targeted device -> should trigger on state change + device_entity_2 = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light_2", + device_id=device_entry.id, + ).entity_id + await hass.async_block_till_done() + + await set_state(device_entity_2, STATE_ON) + + assert_entity_calls_and_reset(device_entity_2) + + # Test untargeted entity -> should not trigger + await set_state(untargeted_entity, STATE_ON) + + assert len(calls) == 0 + calls.clear() + + # Add label to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, labels={label}) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_OFF) + + assert_entity_calls_and_reset(untargeted_entity) + + # Remove label from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, labels={}) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + + assert len(calls) == 0 + + # Add area to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, area_id=area) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + + assert_entity_calls_and_reset(untargeted_entity) + + # Remove area from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, area_id=None) + await hass.async_block_till_done() + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + + assert len(calls) == 0 + + # Add area to untargeted device -> should trigger on state change + device_reg.async_update_device(untargeted_device_entry.id, area_id=area) + await hass.async_block_till_done() + + await set_state(untargeted_device_entity, STATE_ON) + + assert_entity_calls_and_reset(untargeted_device_entity) + + # Remove area from untargeted device -> should not trigger anymore + device_reg.async_update_device(untargeted_device_entry.id, area_id=None) + await hass.async_block_till_done() + await set_state(untargeted_device_entity, STATE_OFF) + await set_state(untargeted_device_entity, STATE_ON) + + assert len(calls) == 0 + + # Set the untargeted area on the untargeted entity -> should not trigger + untracked_area = ar.async_get(hass).async_create("Untargeted Area").id + entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_ON) + assert len(calls) == 0 + + # Set targeted floor on the untargeted area -> should trigger now + ar.async_get(hass).async_update(untracked_area, floor_id=floor) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_OFF) + assert_entity_calls_and_reset(untargeted_entity) + + # Remove untargeted area from targeted floor -> should not trigger anymore + ar.async_get(hass).async_update(untracked_area, floor_id=None) + await hass.async_block_till_done() + + await set_state(untargeted_entity, STATE_ON) + await set_state(untargeted_entity, STATE_OFF) + assert len(calls) == 0 + + # After unsubscribing, changes should not trigger + unsub() + + for entity_id in (targeted_entity, device_entity, untargeted_entity): + await set_state(entity_id, STATE_OFF) + await set_state(entity_id, STATE_ON) + assert len(calls) == 0 From 5d553e56411315b676e0e43ca5f14fd47e451dd4 Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 19:19:49 +0100 Subject: [PATCH 02/19] Use class to manage subscriptions --- homeassistant/helpers/trigger.py | 129 +++++++++++++++++++------------ 1 file changed, 81 insertions(+), 48 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 5f8f45834cd9e0..370954de7448a6 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -852,24 +852,32 @@ def async_extract_referenced_entity_ids( return selected -def async_track_target_selector_state_change_event( - hass: HomeAssistant, - target_selector_config: ConfigType, - action: Callable[[Event[EventStateChangedData]], Any], - job_type: HassJobType | None = None, -) -> CALLBACK_TYPE: - """Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector.""" - selector_data = TargetSelectorData(target_selector_config) - if not selector_data.has_any_selector: - _LOGGER.warning( - "Target selector %s does not have any selectors defined", - target_selector_config, - ) - return lambda: None +class TargetSelectorStateChangeTracker: + """Helper class to manage state change tracking for target selectors.""" - def track_entities_state_change() -> CALLBACK_TYPE: + def __init__( + self, + hass: HomeAssistant, + selector_data: TargetSelectorData, + job_type: HassJobType | None, + action: Callable[[Event[EventStateChangedData]], Any], + ) -> None: + """Initialize the state change tracker.""" + self._hass = hass + self._selector_data = selector_data + self._job_type = job_type + self._action = action + + self._state_change_unsub: CALLBACK_TYPE | None = None + self._registry_unsubs: list[CALLBACK_TYPE] = [] + + self._setup_tracking() + self._setup_registry_listeners() + + def _track_entities_state_change(self) -> CALLBACK_TYPE: + """Set up state change tracking for currently selected entities.""" selected = async_extract_referenced_entity_ids( - hass, selector_data, expand_group=False + self._hass, self._selector_data, expand_group=False ) @callback @@ -879,45 +887,70 @@ def state_change_listener(event: Event[EventStateChangedData]) -> None: event.data["entity_id"] in selected.referenced or event.data["entity_id"] in selected.indirectly_referenced ): - action(event) + self._action(event) tracked_entities = selected.referenced.union(selected.indirectly_referenced) _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) return async_track_state_change_event( - hass, tracked_entities, state_change_listener, job_type=job_type + self._hass, tracked_entities, state_change_listener, job_type=self._job_type ) - unsub_state_change = track_entities_state_change() - - def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: - # TODO(abmantis): Check if there is a better way to do this - nonlocal unsub_state_change - unsub_state_change() - unsub_state_change = track_entities_state_change() - - unsub_registry_updates = [ - hass.bus.async_listen( - entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, - resubscribe_state_change_event, - # TODO(abmantis): filter for entities that match the target selector? - # event_filter=self._filter_entity_registry_changes, - ), - hass.bus.async_listen( - device_registry.EVENT_DEVICE_REGISTRY_UPDATED, - resubscribe_state_change_event, - ), - hass.bus.async_listen( - area_registry.EVENT_AREA_REGISTRY_UPDATED, - resubscribe_state_change_event, - ), - ] + def _setup_tracking(self) -> None: + """Initialize state change tracking.""" + self._state_change_unsub = self._track_entities_state_change() + + def _setup_registry_listeners(self) -> None: + """Set up listeners for registry changes that require resubscription.""" - def unsub() -> None: - """Unsubscribe from state change and registry update events.""" - for registry_unsub in unsub_registry_updates: + @callback + def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: + """Resubscribe to state change events when registry changes.""" + if self._state_change_unsub: + self._state_change_unsub() + self._state_change_unsub = self._track_entities_state_change() + + self._registry_unsubs = [ + self._hass.bus.async_listen( + entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, + resubscribe_state_change_event, + # TODO(abmantis): filter for entities that match the target selector? + # event_filter=self._filter_entity_registry_changes, + ), + self._hass.bus.async_listen( + device_registry.EVENT_DEVICE_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + self._hass.bus.async_listen( + area_registry.EVENT_AREA_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + ] + + def unsub(self) -> None: + """Unsubscribe from all events.""" + for registry_unsub in self._registry_unsubs: registry_unsub() - unsub_registry_updates.clear() - unsub_state_change() + self._registry_unsubs.clear() + if self._state_change_unsub: + self._state_change_unsub() + self._state_change_unsub = None + + +def async_track_target_selector_state_change_event( + hass: HomeAssistant, + target_selector_config: ConfigType, + action: Callable[[Event[EventStateChangedData]], Any], + job_type: HassJobType | None = None, +) -> CALLBACK_TYPE: + """Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector.""" + selector_data = TargetSelectorData(target_selector_config) + if not selector_data.has_any_selector: + _LOGGER.warning( + "Target selector %s does not have any selectors defined", + target_selector_config, + ) + return lambda: None - return unsub + tracker = TargetSelectorStateChangeTracker(hass, selector_data, job_type, action) + return tracker.unsub From 0ead4c033e7bb4c8a5e2d193eaf5608c23e0895b Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 22:55:38 +0100 Subject: [PATCH 03/19] Move target selector extractor method to common module --- homeassistant/helpers/selector.py | 240 +++++++++++++++++++++++++++- homeassistant/helpers/service.py | 250 ++---------------------------- 2 files changed, 253 insertions(+), 237 deletions(-) diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index acb91ddc148bfa..50beaae3057ca0 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -3,21 +3,41 @@ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence +import dataclasses from enum import StrEnum from functools import cache import importlib -from typing import Any, Literal, Required, TypedDict, cast +from logging import Logger +from typing import Any, Literal, Required, TypedDict, TypeGuard, cast from uuid import UUID import voluptuous as vol -from homeassistant.const import CONF_MODE, CONF_UNIT_OF_MEASUREMENT -from homeassistant.core import split_entity_id, valid_entity_id +from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + ATTR_FLOOR_ID, + ATTR_LABEL_ID, + CONF_MODE, + CONF_UNIT_OF_MEASUREMENT, + ENTITY_MATCH_NONE, +) +from homeassistant.core import HomeAssistant, split_entity_id, valid_entity_id from homeassistant.generated.countries import COUNTRIES from homeassistant.util import decorator from homeassistant.util.yaml import dumper -from . import config_validation as cv +from . import ( + area_registry as ar, + config_validation as cv, + device_registry as dr, + entity_registry as er, + floor_registry as fr, + group, + label_registry as lr, +) +from .typing import ConfigType SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry() @@ -1551,3 +1571,215 @@ def __call__(self, data: Any) -> str: dumper, "tag:yaml.org,2002:map", value.serialize() ), ) + + +def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: + """Check if ids can match anything.""" + return ids not in (None, ENTITY_MATCH_NONE) + + +class TargetSelectorData: + """Class to hold data of target selector.""" + + __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids") + + def __init__(self, config: ConfigType) -> None: + """Extract ids from the config.""" + entity_ids: str | list | None = config.get(ATTR_ENTITY_ID) + device_ids: str | list | None = config.get(ATTR_DEVICE_ID) + area_ids: str | list | None = config.get(ATTR_AREA_ID) + floor_ids: str | list | None = config.get(ATTR_FLOOR_ID) + label_ids: str | list | None = config.get(ATTR_LABEL_ID) + + self.entity_ids = ( + set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() + ) + self.device_ids = ( + set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() + ) + self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() + self.floor_ids = ( + set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() + ) + self.label_ids = ( + set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() + ) + + @property + def has_any_selector(self) -> bool: + """Determine if any selectors are present.""" + return bool( + self.entity_ids + or self.device_ids + or self.area_ids + or self.floor_ids + or self.label_ids + ) + + +@dataclasses.dataclass(slots=True) +class SelectedEntities: + """Class to hold the selected entities.""" + + # Entities that were explicitly mentioned. + referenced: set[str] = dataclasses.field(default_factory=set) + + # Entities that were referenced via device/area/floor/label ID. + # Should not trigger a warning when they don't exist. + indirectly_referenced: set[str] = dataclasses.field(default_factory=set) + + # Referenced items that could not be found. + missing_devices: set[str] = dataclasses.field(default_factory=set) + missing_areas: set[str] = dataclasses.field(default_factory=set) + missing_floors: set[str] = dataclasses.field(default_factory=set) + missing_labels: set[str] = dataclasses.field(default_factory=set) + + referenced_devices: set[str] = dataclasses.field(default_factory=set) + referenced_areas: set[str] = dataclasses.field(default_factory=set) + + def log_missing(self, missing_entities: set[str], logger: Logger) -> None: + """Log about missing items.""" + parts = [] + for label, items in ( + ("floors", self.missing_floors), + ("areas", self.missing_areas), + ("devices", self.missing_devices), + ("entities", missing_entities), + ("labels", self.missing_labels), + ): + if items: + parts.append(f"{label} {', '.join(sorted(items))}") + + if not parts: + return + + logger.warning( + "Referenced %s are missing or not currently available", + ", ".join(parts), + ) + + +def async_extract_referenced_entity_ids( + hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True +) -> SelectedEntities: + """Extract referenced entity IDs from a target selector.""" + selected = SelectedEntities() + + if not selector_data.has_any_selector: + return selected + + entity_ids: set[str] | list[str] = selector_data.entity_ids + if expand_group: + entity_ids = group.expand_entity_ids(hass, entity_ids) + + selected.referenced.update(entity_ids) + + if ( + not selector_data.device_ids + and not selector_data.area_ids + and not selector_data.floor_ids + and not selector_data.label_ids + ): + return selected + + entities = er.async_get(hass).entities + dev_reg = dr.async_get(hass) + area_reg = ar.async_get(hass) + + if selector_data.floor_ids: + floor_reg = fr.async_get(hass) + for floor_id in selector_data.floor_ids: + if floor_id not in floor_reg.floors: + selected.missing_floors.add(floor_id) + + for area_id in selector_data.area_ids: + if area_id not in area_reg.areas: + selected.missing_areas.add(area_id) + + for device_id in selector_data.device_ids: + if device_id not in dev_reg.devices: + selected.missing_devices.add(device_id) + + if selector_data.label_ids: + label_reg = lr.async_get(hass) + for label_id in selector_data.label_ids: + if label_id not in label_reg.labels: + selected.missing_labels.add(label_id) + + for entity_entry in entities.get_entries_for_label(label_id): + if ( + entity_entry.entity_category is None + and entity_entry.hidden_by is None + ): + selected.indirectly_referenced.add(entity_entry.entity_id) + + for device_entry in dev_reg.devices.get_devices_for_label(label_id): + selected.referenced_devices.add(device_entry.id) + + for area_entry in area_reg.areas.get_areas_for_label(label_id): + selected.referenced_areas.add(area_entry.id) + + # Find areas for targeted floors + if selector_data.floor_ids: + selected.referenced_areas.update( + area_entry.id + for floor_id in selector_data.floor_ids + for area_entry in area_reg.areas.get_areas_for_floor(floor_id) + ) + + selected.referenced_areas.update(selector_data.area_ids) + selected.referenced_devices.update(selector_data.device_ids) + + if not selected.referenced_areas and not selected.referenced_devices: + return selected + + # Add indirectly referenced by device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in selected.referenced_devices + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if (entry.entity_category is None and entry.hidden_by is None) + ) + + # Find devices for targeted areas + referenced_devices_by_area: set[str] = set() + if selected.referenced_areas: + for area_id in selected.referenced_areas: + referenced_devices_by_area.update( + device_entry.id + for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) + ) + selected.referenced_devices.update(referenced_devices_by_area) + + # Add indirectly referenced by area + selected.indirectly_referenced.update( + entry.entity_id + for area_id in selected.referenced_areas + # The entity's area matches a targeted area + for entry in entities.get_entries_for_area_id(area_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if entry.entity_category is None and entry.hidden_by is None + ) + # Add indirectly referenced by area through device + selected.indirectly_referenced.update( + entry.entity_id + for device_id in referenced_devices_by_area + for entry in entities.get_entries_for_device_id(device_id) + # Do not add entities which are hidden or which are config + # or diagnostic entities. + if ( + entry.entity_category is None + and entry.hidden_by is None + and ( + # The entity's device matches a device referenced + # by an area and the entity + # has no explicitly set area + not entry.area_id + ) + ) + ) + + return selected diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 51d9c97ceebd12..0d83806db84863 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -4,22 +4,17 @@ import asyncio from collections.abc import Callable, Coroutine, Iterable -import dataclasses from enum import Enum from functools import cache, partial import logging from types import ModuleType -from typing import TYPE_CHECKING, Any, TypedDict, TypeGuard, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast import voluptuous as vol from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.const import ( - ATTR_AREA_ID, - ATTR_DEVICE_ID, ATTR_ENTITY_ID, - ATTR_FLOOR_ID, - ATTR_LABEL_ID, CONF_ACTION, CONF_ENTITY_ID, CONF_SERVICE_DATA, @@ -54,17 +49,18 @@ from homeassistant.util.yaml.loader import JSON_TYPE from . import ( - area_registry, config_validation as cv, device_registry, entity_registry, - floor_registry, - label_registry, template, translation, ) -from .group import expand_entity_ids -from .selector import TargetSelector +from .selector import ( + SelectedEntities, + TargetSelector, + TargetSelectorData, + async_extract_referenced_entity_ids, +) from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType if TYPE_CHECKING: @@ -223,89 +219,6 @@ class ServiceParams(TypedDict): target: dict | None -class ServiceTargetSelector: - """Class to hold a target selector for a service.""" - - __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids") - - def __init__(self, service_call: ServiceCall) -> None: - """Extract ids from service call data.""" - service_call_data = service_call.data - entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID) - device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID) - area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID) - floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID) - label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID) - - self.entity_ids = ( - set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() - ) - self.device_ids = ( - set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() - ) - self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() - self.floor_ids = ( - set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() - ) - self.label_ids = ( - set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() - ) - - @property - def has_any_selector(self) -> bool: - """Determine if any selectors are present.""" - return bool( - self.entity_ids - or self.device_ids - or self.area_ids - or self.floor_ids - or self.label_ids - ) - - -@dataclasses.dataclass(slots=True) -class SelectedEntities: - """Class to hold the selected entities.""" - - # Entities that were explicitly mentioned. - referenced: set[str] = dataclasses.field(default_factory=set) - - # Entities that were referenced via device/area/floor/label ID. - # Should not trigger a warning when they don't exist. - indirectly_referenced: set[str] = dataclasses.field(default_factory=set) - - # Referenced items that could not be found. - missing_devices: set[str] = dataclasses.field(default_factory=set) - missing_areas: set[str] = dataclasses.field(default_factory=set) - missing_floors: set[str] = dataclasses.field(default_factory=set) - missing_labels: set[str] = dataclasses.field(default_factory=set) - - # Referenced devices - referenced_devices: set[str] = dataclasses.field(default_factory=set) - referenced_areas: set[str] = dataclasses.field(default_factory=set) - - def log_missing(self, missing_entities: set[str]) -> None: - """Log about missing items.""" - parts = [] - for label, items in ( - ("floors", self.missing_floors), - ("areas", self.missing_areas), - ("devices", self.missing_devices), - ("entities", missing_entities), - ("labels", self.missing_labels), - ): - if items: - parts.append(f"{label} {', '.join(sorted(items))}") - - if not parts: - return - - _LOGGER.warning( - "Referenced %s are missing or not currently available", - ", ".join(parts), - ) - - @bind_hass def call_from_config( hass: HomeAssistant, @@ -464,7 +377,8 @@ async def async_extract_entities[_EntityT: Entity]( if data_ent_id == ENTITY_MATCH_ALL: return [entity for entity in entities if entity.available] - referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) + selector_data = TargetSelectorData(service_call.data) + referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group) combined = referenced.referenced | referenced.indirectly_referenced found = [] @@ -480,7 +394,7 @@ async def async_extract_entities[_EntityT: Entity]( found.append(entity) - referenced.log_missing(referenced.referenced & combined) + referenced.log_missing(referenced.referenced & combined, _LOGGER) return found @@ -493,149 +407,18 @@ async def async_extract_entity_ids( Will convert group entity ids to the entity ids it represents. """ - referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) + selector_data = TargetSelectorData(service_call.data) + referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group) return referenced.referenced | referenced.indirectly_referenced -def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: - """Check if ids can match anything.""" - return ids not in (None, ENTITY_MATCH_NONE) - - -@bind_hass -def async_extract_referenced_entity_ids( - hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True -) -> SelectedEntities: - """Extract referenced entity IDs from a service call.""" - selector = ServiceTargetSelector(service_call) - selected = SelectedEntities() - - if not selector.has_any_selector: - return selected - - entity_ids: set[str] | list[str] = selector.entity_ids - if expand_group: - entity_ids = expand_entity_ids(hass, entity_ids) - - selected.referenced.update(entity_ids) - - if ( - not selector.device_ids - and not selector.area_ids - and not selector.floor_ids - and not selector.label_ids - ): - return selected - - entities = entity_registry.async_get(hass).entities - dev_reg = device_registry.async_get(hass) - area_reg = area_registry.async_get(hass) - - if selector.floor_ids: - floor_reg = floor_registry.async_get(hass) - for floor_id in selector.floor_ids: - if floor_id not in floor_reg.floors: - selected.missing_floors.add(floor_id) - - for area_id in selector.area_ids: - if area_id not in area_reg.areas: - selected.missing_areas.add(area_id) - - for device_id in selector.device_ids: - if device_id not in dev_reg.devices: - selected.missing_devices.add(device_id) - - if selector.label_ids: - label_reg = label_registry.async_get(hass) - for label_id in selector.label_ids: - if label_id not in label_reg.labels: - selected.missing_labels.add(label_id) - - for entity_entry in entities.get_entries_for_label(label_id): - if ( - entity_entry.entity_category is None - and entity_entry.hidden_by is None - ): - selected.indirectly_referenced.add(entity_entry.entity_id) - - for device_entry in dev_reg.devices.get_devices_for_label(label_id): - selected.referenced_devices.add(device_entry.id) - - for area_entry in area_reg.areas.get_areas_for_label(label_id): - selected.referenced_areas.add(area_entry.id) - - # Find areas for targeted floors - if selector.floor_ids: - selected.referenced_areas.update( - area_entry.id - for floor_id in selector.floor_ids - for area_entry in area_reg.areas.get_areas_for_floor(floor_id) - ) - - selected.referenced_areas.update(selector.area_ids) - selected.referenced_devices.update(selector.device_ids) - - if not selected.referenced_areas and not selected.referenced_devices: - return selected - - # Add indirectly referenced by device - selected.indirectly_referenced.update( - entry.entity_id - for device_id in selected.referenced_devices - for entry in entities.get_entries_for_device_id(device_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if (entry.entity_category is None and entry.hidden_by is None) - ) - - # Find devices for targeted areas - referenced_devices_by_area: set[str] = set() - if selected.referenced_areas: - for area_id in selected.referenced_areas: - referenced_devices_by_area.update( - device_entry.id - for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) - ) - selected.referenced_devices.update(referenced_devices_by_area) - - # Add indirectly referenced by area - selected.indirectly_referenced.update( - entry.entity_id - for area_id in selected.referenced_areas - # The entity's area matches a targeted area - for entry in entities.get_entries_for_area_id(area_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if entry.entity_category is None and entry.hidden_by is None - ) - # Add indirectly referenced by area through device - selected.indirectly_referenced.update( - entry.entity_id - for device_id in referenced_devices_by_area - for entry in entities.get_entries_for_device_id(device_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if ( - entry.entity_category is None - and entry.hidden_by is None - and ( - # The entity's device matches a device referenced - # by an area and the entity - # has no explicitly set area - not entry.area_id - ) - ) - ) - - return selected - - @bind_hass async def async_extract_config_entry_ids( hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True ) -> set[str]: """Extract referenced config entry ids from a service call.""" - referenced = async_extract_referenced_entity_ids(hass, service_call, expand_group) + selector_data = TargetSelectorData(service_call.data) + referenced = async_extract_referenced_entity_ids(hass, selector_data, expand_group) ent_reg = entity_registry.async_get(hass) dev_reg = device_registry.async_get(hass) config_entry_ids: set[str] = set() @@ -950,7 +733,8 @@ async def entity_service_call( all_referenced: set[str] | None = None else: # A set of entities we're trying to target. - referenced = async_extract_referenced_entity_ids(hass, call, True) + selector_data = TargetSelectorData(call.data) + referenced = async_extract_referenced_entity_ids(hass, selector_data, True) all_referenced = referenced.referenced | referenced.indirectly_referenced # If the service function is a string, we'll pass it the service call data @@ -975,7 +759,7 @@ async def entity_service_call( missing = referenced.referenced.copy() for entity in entity_candidates: missing.discard(entity.entity_id) - referenced.log_missing(missing) + referenced.log_missing(missing, _LOGGER) entities: list[Entity] = [] for entity in entity_candidates: From 0335c9e32b66cb68e15feaf1eacbd84aaa646a9e Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 22:59:32 +0100 Subject: [PATCH 04/19] Use common method in triggers.py --- homeassistant/helpers/trigger.py | 235 +------------------------------ 1 file changed, 3 insertions(+), 232 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 370954de7448a6..e3dc4fd57536d6 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -6,26 +6,19 @@ import asyncio from collections import defaultdict from collections.abc import Callable, Coroutine, Iterable -import dataclasses from dataclasses import dataclass, field import functools import logging -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, TypeGuard, cast +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast import voluptuous as vol from homeassistant.const import ( - ATTR_AREA_ID, - ATTR_DEVICE_ID, - ATTR_ENTITY_ID, - ATTR_FLOOR_ID, - ATTR_LABEL_ID, CONF_ALIAS, CONF_ENABLED, CONF_ID, CONF_PLATFORM, CONF_VARIABLES, - ENTITY_MATCH_NONE, ) from homeassistant.core import ( CALLBACK_TYPE, @@ -49,17 +42,10 @@ from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml.loader import JSON_TYPE -from . import ( - area_registry, - config_validation as cv, - device_registry, - entity_registry, - floor_registry, - label_registry, -) +from . import area_registry, config_validation as cv, device_registry, entity_registry from .event import EventStateChangedData, async_track_state_change_event -from .group import expand_entity_ids from .integration_platform import async_process_integration_platforms +from .selector import TargetSelectorData, async_extract_referenced_entity_ids from .template import Template from .typing import ConfigType, TemplateVarsType @@ -637,221 +623,6 @@ async def async_get_all_descriptions( return new_descriptions_cache -def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: - """Check if ids can match anything.""" - return ids not in (None, ENTITY_MATCH_NONE) - - -# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places -class TargetSelectorData: - """Class to hold data of target selector.""" - - __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids") - - def __init__(self, config: ConfigType) -> None: - """Extract ids from the config.""" - entity_ids: str | list | None = config.get(ATTR_ENTITY_ID) - device_ids: str | list | None = config.get(ATTR_DEVICE_ID) - area_ids: str | list | None = config.get(ATTR_AREA_ID) - floor_ids: str | list | None = config.get(ATTR_FLOOR_ID) - label_ids: str | list | None = config.get(ATTR_LABEL_ID) - - self.entity_ids = ( - set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set() - ) - self.device_ids = ( - set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set() - ) - self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set() - self.floor_ids = ( - set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set() - ) - self.label_ids = ( - set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set() - ) - - @property - def has_any_selector(self) -> bool: - """Determine if any selectors are present.""" - return bool( - self.entity_ids - or self.device_ids - or self.area_ids - or self.floor_ids - or self.label_ids - ) - - -# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places -@dataclasses.dataclass(slots=True) -class SelectedEntities: - """Class to hold the selected entities.""" - - # Entities that were explicitly mentioned. - referenced: set[str] = dataclasses.field(default_factory=set) - - # Entities that were referenced via device/area/floor/label ID. - # Should not trigger a warning when they don't exist. - indirectly_referenced: set[str] = dataclasses.field(default_factory=set) - - # Referenced items that could not be found. - missing_devices: set[str] = dataclasses.field(default_factory=set) - missing_areas: set[str] = dataclasses.field(default_factory=set) - missing_floors: set[str] = dataclasses.field(default_factory=set) - missing_labels: set[str] = dataclasses.field(default_factory=set) - - referenced_devices: set[str] = dataclasses.field(default_factory=set) - referenced_areas: set[str] = dataclasses.field(default_factory=set) - - def log_missing(self, missing_entities: set[str]) -> None: - """Log about missing items.""" - parts = [] - for label, items in ( - ("floors", self.missing_floors), - ("areas", self.missing_areas), - ("devices", self.missing_devices), - ("entities", missing_entities), - ("labels", self.missing_labels), - ): - if items: - parts.append(f"{label} {', '.join(sorted(items))}") - - if not parts: - return - - _LOGGER.warning( - "Referenced %s are missing or not currently available", - ", ".join(parts), - ) - - -# TODO(abmantis): Since this is a copy from the service one, move it to a common place and use it in both places -def async_extract_referenced_entity_ids( - hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True -) -> SelectedEntities: - """Extract referenced entity IDs from a target selector.""" - selected = SelectedEntities() - - if not selector_data.has_any_selector: - return selected - - entity_ids: set[str] | list[str] = selector_data.entity_ids - if expand_group: - entity_ids = expand_entity_ids(hass, entity_ids) - - selected.referenced.update(entity_ids) - - if ( - not selector_data.device_ids - and not selector_data.area_ids - and not selector_data.floor_ids - and not selector_data.label_ids - ): - return selected - - entities = entity_registry.async_get(hass).entities - dev_reg = device_registry.async_get(hass) - area_reg = area_registry.async_get(hass) - - if selector_data.floor_ids: - floor_reg = floor_registry.async_get(hass) - for floor_id in selector_data.floor_ids: - if floor_id not in floor_reg.floors: - selected.missing_floors.add(floor_id) - - for area_id in selector_data.area_ids: - if area_id not in area_reg.areas: - selected.missing_areas.add(area_id) - - for device_id in selector_data.device_ids: - if device_id not in dev_reg.devices: - selected.missing_devices.add(device_id) - - if selector_data.label_ids: - label_reg = label_registry.async_get(hass) - for label_id in selector_data.label_ids: - if label_id not in label_reg.labels: - selected.missing_labels.add(label_id) - - for entity_entry in entities.get_entries_for_label(label_id): - if ( - entity_entry.entity_category is None - and entity_entry.hidden_by is None - ): - selected.indirectly_referenced.add(entity_entry.entity_id) - - for device_entry in dev_reg.devices.get_devices_for_label(label_id): - selected.referenced_devices.add(device_entry.id) - - for area_entry in area_reg.areas.get_areas_for_label(label_id): - selected.referenced_areas.add(area_entry.id) - - # Find areas for targeted floors - if selector_data.floor_ids: - selected.referenced_areas.update( - area_entry.id - for floor_id in selector_data.floor_ids - for area_entry in area_reg.areas.get_areas_for_floor(floor_id) - ) - - selected.referenced_areas.update(selector_data.area_ids) - selected.referenced_devices.update(selector_data.device_ids) - - if not selected.referenced_areas and not selected.referenced_devices: - return selected - - # Add indirectly referenced by device - selected.indirectly_referenced.update( - entry.entity_id - for device_id in selected.referenced_devices - for entry in entities.get_entries_for_device_id(device_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if (entry.entity_category is None and entry.hidden_by is None) - ) - - # Find devices for targeted areas - referenced_devices_by_area: set[str] = set() - if selected.referenced_areas: - for area_id in selected.referenced_areas: - referenced_devices_by_area.update( - device_entry.id - for device_entry in dev_reg.devices.get_devices_for_area_id(area_id) - ) - selected.referenced_devices.update(referenced_devices_by_area) - - # Add indirectly referenced by area - selected.indirectly_referenced.update( - entry.entity_id - for area_id in selected.referenced_areas - # The entity's area matches a targeted area - for entry in entities.get_entries_for_area_id(area_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if entry.entity_category is None and entry.hidden_by is None - ) - # Add indirectly referenced by area through device - selected.indirectly_referenced.update( - entry.entity_id - for device_id in referenced_devices_by_area - for entry in entities.get_entries_for_device_id(device_id) - # Do not add entities which are hidden or which are config - # or diagnostic entities. - if ( - entry.entity_category is None - and entry.hidden_by is None - and ( - # The entity's device matches a device referenced - # by an area and the entity - # has no explicitly set area - not entry.area_id - ) - ) - ) - - return selected - - class TargetSelectorStateChangeTracker: """Helper class to manage state change tracking for target selectors.""" From 695f47c5fc96e7875511a297c0efe8d0cd79c1ad Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 23:24:01 +0100 Subject: [PATCH 05/19] Add missed components --- .../components/homeassistant/__init__.py | 9 +++++++-- homeassistant/components/homekit/__init__.py | 9 ++++++--- homeassistant/components/lifx/manager.py | 10 +++++++--- homeassistant/components/unifiprotect/services.py | 15 ++++++++++----- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/homeassistant/__init__.py b/homeassistant/components/homeassistant/__init__.py index d5dabfa2e0834d..cc03cb731018fe 100644 --- a/homeassistant/components/homeassistant/__init__.py +++ b/homeassistant/components/homeassistant/__init__.py @@ -42,9 +42,12 @@ ) from homeassistant.helpers.entity_component import async_update_entity from homeassistant.helpers.issue_registry import IssueSeverity +from homeassistant.helpers.selector import ( + TargetSelectorData, + async_extract_referenced_entity_ids, +) from homeassistant.helpers.service import ( async_extract_config_entry_ids, - async_extract_referenced_entity_ids, async_register_admin_service, ) from homeassistant.helpers.signal import KEY_HA_STOP @@ -111,7 +114,9 @@ async def async_save_persistent_states(service: ServiceCall) -> None: async def async_handle_turn_service(service: ServiceCall) -> None: """Handle calls to homeassistant.turn_on/off.""" - referenced = async_extract_referenced_entity_ids(hass, service) + referenced = async_extract_referenced_entity_ids( + hass, TargetSelectorData(service.data) + ) all_referenced = referenced.referenced | referenced.indirectly_referenced # Generic turn on/off method requires entity id diff --git a/homeassistant/components/homekit/__init__.py b/homeassistant/components/homekit/__init__.py index 8b526b62302801..8edf4151ac44a9 100644 --- a/homeassistant/components/homekit/__init__.py +++ b/homeassistant/components/homekit/__init__.py @@ -75,10 +75,11 @@ EntityFilter, ) from homeassistant.helpers.reload import async_integration_yaml_config -from homeassistant.helpers.service import ( +from homeassistant.helpers.selector import ( + TargetSelectorData, async_extract_referenced_entity_ids, - async_register_admin_service, ) +from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.start import async_at_started from homeassistant.helpers.typing import ConfigType from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -482,7 +483,9 @@ async def async_handle_homekit_reset_accessory(service: ServiceCall) -> None: async def async_handle_homekit_unpair(service: ServiceCall) -> None: """Handle unpair HomeKit service call.""" - referenced = async_extract_referenced_entity_ids(hass, service) + referenced = async_extract_referenced_entity_ids( + hass, TargetSelectorData(service.data) + ) dev_reg = dr.async_get(hass) for device_id in referenced.referenced_devices: if not (dev_reg_ent := dev_reg.async_get(device_id)): diff --git a/homeassistant/components/lifx/manager.py b/homeassistant/components/lifx/manager.py index 33712441157fad..54af8f67c9ba72 100644 --- a/homeassistant/components/lifx/manager.py +++ b/homeassistant/components/lifx/manager.py @@ -28,7 +28,10 @@ from homeassistant.const import ATTR_MODE from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.service import async_extract_referenced_entity_ids +from homeassistant.helpers.selector import ( + TargetSelectorData, + async_extract_referenced_entity_ids, +) from .const import _ATTR_COLOR_TEMP, ATTR_THEME, DOMAIN from .coordinator import LIFXConfigEntry, LIFXUpdateCoordinator @@ -268,7 +271,9 @@ def async_setup(self) -> None: async def service_handler(service: ServiceCall) -> None: """Apply a service, i.e. start an effect.""" - referenced = async_extract_referenced_entity_ids(self.hass, service) + referenced = async_extract_referenced_entity_ids( + self.hass, TargetSelectorData(service.data) + ) all_referenced = referenced.referenced | referenced.indirectly_referenced if all_referenced: await self.start_effect(all_referenced, service.service, **service.data) @@ -499,6 +504,5 @@ async def start_effect( if self.entry_id_to_entity_id[entry.entry_id] in entity_ids: coordinators.append(entry.runtime_data) bulbs.append(entry.runtime_data.device) - if start_effect_func := self._effect_dispatch.get(service): await start_effect_func(self, bulbs, coordinators, **kwargs) diff --git a/homeassistant/components/unifiprotect/services.py b/homeassistant/components/unifiprotect/services.py index 40fe0a991f2f40..8f9b0cb0c3baac 100644 --- a/homeassistant/components/unifiprotect/services.py +++ b/homeassistant/components/unifiprotect/services.py @@ -26,7 +26,10 @@ device_registry as dr, entity_registry as er, ) -from homeassistant.helpers.service import async_extract_referenced_entity_ids +from homeassistant.helpers.selector import ( + TargetSelectorData, + async_extract_referenced_entity_ids, +) from homeassistant.util.json import JsonValueType from homeassistant.util.read_only_dict import ReadOnlyDict @@ -115,7 +118,7 @@ def _async_get_ufp_instance(hass: HomeAssistant, device_id: str) -> ProtectApiCl @callback def _async_get_ufp_camera(call: ServiceCall) -> Camera: - ref = async_extract_referenced_entity_ids(call.hass, call) + ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data)) entity_registry = er.async_get(call.hass) entity_id = ref.indirectly_referenced.pop() @@ -133,7 +136,7 @@ def _async_get_protect_from_call(call: ServiceCall) -> set[ProtectApiClient]: return { _async_get_ufp_instance(call.hass, device_id) for device_id in async_extract_referenced_entity_ids( - call.hass, call + call.hass, TargetSelectorData(call.data) ).referenced_devices } @@ -196,7 +199,7 @@ def _async_unique_id_to_mac(unique_id: str) -> str: async def set_chime_paired_doorbells(call: ServiceCall) -> None: """Set paired doorbells on chime.""" - ref = async_extract_referenced_entity_ids(call.hass, call) + ref = async_extract_referenced_entity_ids(call.hass, TargetSelectorData(call.data)) entity_registry = er.async_get(call.hass) entity_id = ref.indirectly_referenced.pop() @@ -211,7 +214,9 @@ async def set_chime_paired_doorbells(call: ServiceCall) -> None: assert chime is not None call.data = ReadOnlyDict(call.data.get("doorbells") or {}) - doorbell_refs = async_extract_referenced_entity_ids(call.hass, call) + doorbell_refs = async_extract_referenced_entity_ids( + call.hass, TargetSelectorData(call.data) + ) doorbell_ids: set[str] = set() for camera_id in doorbell_refs.referenced | doorbell_refs.indirectly_referenced: doorbell_sensor = entity_registry.async_get(camera_id) From 2d931c56b1e67d6c79a63065ecb5087fc95a9809 Mon Sep 17 00:00:00 2001 From: abmantis Date: Thu, 3 Jul 2025 23:35:06 +0100 Subject: [PATCH 06/19] Simplify tracker class --- homeassistant/helpers/trigger.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 46088b1d449e07..7241f44f95816b 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -646,10 +646,10 @@ def __init__( self._state_change_unsub: CALLBACK_TYPE | None = None self._registry_unsubs: list[CALLBACK_TYPE] = [] - self._setup_tracking() self._setup_registry_listeners() + self._track_entities_state_change() - def _track_entities_state_change(self) -> CALLBACK_TYPE: + def _track_entities_state_change(self) -> None: """Set up state change tracking for currently selected entities.""" selected = async_extract_referenced_entity_ids( self._hass, self._selector_data, expand_group=False @@ -667,14 +667,10 @@ def state_change_listener(event: Event[EventStateChangedData]) -> None: tracked_entities = selected.referenced.union(selected.indirectly_referenced) _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) - return async_track_state_change_event( + self._state_change_unsub = async_track_state_change_event( self._hass, tracked_entities, state_change_listener, job_type=self._job_type ) - def _setup_tracking(self) -> None: - """Initialize state change tracking.""" - self._state_change_unsub = self._track_entities_state_change() - def _setup_registry_listeners(self) -> None: """Set up listeners for registry changes that require resubscription.""" @@ -683,7 +679,7 @@ def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: """Resubscribe to state change events when registry changes.""" if self._state_change_unsub: self._state_change_unsub() - self._state_change_unsub = self._track_entities_state_change() + self._track_entities_state_change() self._registry_unsubs = [ self._hass.bus.async_listen( From 7a57ab4cd2df4b8b62ca6f5904c71786a81824aa Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 7 Jul 2025 13:58:18 +0100 Subject: [PATCH 07/19] Update import --- homeassistant/helpers/trigger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 7241f44f95816b..51583c51bac94a 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -45,7 +45,7 @@ from . import area_registry, config_validation as cv, device_registry, entity_registry from .event import EventStateChangedData, async_track_state_change_event from .integration_platform import async_process_integration_platforms -from .selector import TargetSelectorData, async_extract_referenced_entity_ids +from .target import TargetSelectorData, async_extract_referenced_entity_ids from .template import Template from .typing import ConfigType, TemplateVarsType From 4cc97d260c382ddab7a04eab3a137774114caae8 Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 7 Jul 2025 15:47:51 +0100 Subject: [PATCH 08/19] Implement review suggestion on triggering all test entities --- tests/helpers/test_trigger.py | 121 +++++++++++++++++----------------- 1 file changed, 59 insertions(+), 62 deletions(-) diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 28ef8b94623215..c6ea369a01a758 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -842,14 +842,27 @@ def state_change_callback(event): """Handle state change events.""" calls.append(event) - async def set_state(entity_id, state): - """Set the state of an entity.""" - hass.states.async_set(entity_id, state) + # List of entities to toggle state during the test. This list should be insert-only + # so that all entities are changed every time. + entities_to_set_state = [] + # List of entities that should assert a state change when toggled. Contrary to + # entities_to_set_state, entities should be added and removed. + entities_to_assert_change = [] + last_state = STATE_OFF + + async def toggle_states(): + """Toggle the state of all the entities in test.""" + nonlocal last_state + last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF + for entity_id in entities_to_set_state: + hass.states.async_set(entity_id, last_state) await hass.async_block_till_done() - def assert_entity_calls_and_reset(entity_id: str) -> None: - assert len(calls) == 1 - assert calls[0].data["entity_id"] == entity_id + def assert_entity_calls_and_reset() -> None: + assert len(calls) == len(entities_to_assert_change) + for change_call in calls: + assert change_call.data["entity_id"] in entities_to_assert_change + assert change_call.data["new_state"].state == last_state calls.clear() config_entry = MockConfigEntry(domain="test") @@ -889,9 +902,8 @@ def assert_entity_calls_and_reset(entity_id: str) -> None: targeted_entity = "light.test_light" - for entity_id in (targeted_entity, device_entity, untargeted_entity): - hass.states.async_set(entity_id, STATE_OFF) - await hass.async_block_till_done() + entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity]) + await toggle_states() label = lr.async_get(hass).async_create("Test Label").name area = ar.async_get(hass).async_create("Test Area").id @@ -909,17 +921,9 @@ def assert_entity_calls_and_reset(entity_id: str) -> None: ) # Test directly targeted entity and device - await set_state(targeted_entity, STATE_ON) - await set_state(device_entity, STATE_ON) - - assert len(calls) == 2 - assert calls[0].data["entity_id"] == targeted_entity - assert calls[0].data["old_state"].state == STATE_OFF - assert calls[0].data["new_state"].state == STATE_ON - assert calls[1].data["entity_id"] == device_entity - assert calls[1].data["old_state"].state == STATE_OFF - assert calls[1].data["new_state"].state == STATE_ON - calls.clear() + entities_to_assert_change.extend([targeted_entity, device_entity]) + await toggle_states() + assert_entity_calls_and_reset() # Add new entity to the targeted device -> should trigger on state change device_entity_2 = entity_reg.async_get_or_create( @@ -930,89 +934,82 @@ def assert_entity_calls_and_reset(entity_id: str) -> None: ).entity_id await hass.async_block_till_done() - await set_state(device_entity_2, STATE_ON) - - assert_entity_calls_and_reset(device_entity_2) + entities_to_set_state.append(device_entity_2) + entities_to_assert_change.append(device_entity_2) + await toggle_states() + assert_entity_calls_and_reset() # Test untargeted entity -> should not trigger - await set_state(untargeted_entity, STATE_ON) - - assert len(calls) == 0 - calls.clear() + entities_to_set_state.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Add label to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, labels={label}) await hass.async_block_till_done() - await set_state(untargeted_entity, STATE_OFF) - - assert_entity_calls_and_reset(untargeted_entity) + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Remove label from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, labels={}) await hass.async_block_till_done() - await set_state(untargeted_entity, STATE_ON) - await set_state(untargeted_entity, STATE_OFF) - - assert len(calls) == 0 + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Add area to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, area_id=area) await hass.async_block_till_done() - await set_state(untargeted_entity, STATE_ON) - - assert_entity_calls_and_reset(untargeted_entity) + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Remove area from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, area_id=None) await hass.async_block_till_done() - await set_state(untargeted_entity, STATE_ON) - await set_state(untargeted_entity, STATE_OFF) - - assert len(calls) == 0 + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Add area to untargeted device -> should trigger on state change device_reg.async_update_device(untargeted_device_entry.id, area_id=area) await hass.async_block_till_done() - - await set_state(untargeted_device_entity, STATE_ON) - - assert_entity_calls_and_reset(untargeted_device_entity) + entities_to_set_state.append(untargeted_device_entity) + entities_to_assert_change.append(untargeted_device_entity) + await toggle_states() + assert_entity_calls_and_reset() # Remove area from untargeted device -> should not trigger anymore device_reg.async_update_device(untargeted_device_entry.id, area_id=None) await hass.async_block_till_done() - await set_state(untargeted_device_entity, STATE_OFF) - await set_state(untargeted_device_entity, STATE_ON) - - assert len(calls) == 0 + entities_to_assert_change.remove(untargeted_device_entity) + await toggle_states() + assert_entity_calls_and_reset() # Set the untargeted area on the untargeted entity -> should not trigger untracked_area = ar.async_get(hass).async_create("Untargeted Area").id entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) await hass.async_block_till_done() - - await set_state(untargeted_entity, STATE_ON) - assert len(calls) == 0 + await toggle_states() + assert_entity_calls_and_reset() # Set targeted floor on the untargeted area -> should trigger now ar.async_get(hass).async_update(untracked_area, floor_id=floor) await hass.async_block_till_done() - - await set_state(untargeted_entity, STATE_OFF) - assert_entity_calls_and_reset(untargeted_entity) + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # Remove untargeted area from targeted floor -> should not trigger anymore ar.async_get(hass).async_update(untracked_area, floor_id=None) await hass.async_block_till_done() - - await set_state(untargeted_entity, STATE_ON) - await set_state(untargeted_entity, STATE_OFF) - assert len(calls) == 0 + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() # After unsubscribing, changes should not trigger unsub() - for entity_id in (targeted_entity, device_entity, untargeted_entity): - await set_state(entity_id, STATE_OFF) - await set_state(entity_id, STATE_ON) + await toggle_states() assert len(calls) == 0 From 67a7cf83c2d545dce029c5a054bf29c60aa969b7 Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 7 Jul 2025 17:27:19 +0100 Subject: [PATCH 09/19] Rename class; add comment --- homeassistant/helpers/trigger.py | 13 ++++++++++--- tests/helpers/test_trigger.py | 1 - 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 51583c51bac94a..6f64c60992a48e 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -627,8 +627,8 @@ async def async_get_all_descriptions( return new_descriptions_cache -class TargetSelectorStateChangeTracker: - """Helper class to manage state change tracking for target selectors.""" +class TargetStateChangeTracker: + """Helper class to manage state change tracking for targets.""" def __init__( self, @@ -681,6 +681,13 @@ def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: self._state_change_unsub() self._track_entities_state_change() + # Subscribe to registry updates that can change the entities to track: + # - Entity registry: entity added/removed; entity labels changed; entity area changed. + # - Device registry: device labels changed; device area changed. + # - Area registry: area floor changed. + # + # We don't track other registries (like floor or label registries) because their + # changes don't affect which entities are tracked. self._registry_unsubs = [ self._hass.bus.async_listen( entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, @@ -723,5 +730,5 @@ def async_track_target_selector_state_change_event( ) return lambda: None - tracker = TargetSelectorStateChangeTracker(hass, selector_data, job_type, action) + tracker = TargetStateChangeTracker(hass, selector_data, job_type, action) return tracker.unsub diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index c6ea369a01a758..4dd4da94835fc0 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -1010,6 +1010,5 @@ def assert_entity_calls_and_reset() -> None: # After unsubscribing, changes should not trigger unsub() - await toggle_states() assert len(calls) == 0 From 937671561f584091b121e07f1ae2e6154fba79ff Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 7 Jul 2025 17:50:24 +0100 Subject: [PATCH 10/19] Move to target.py --- homeassistant/helpers/target.py | 123 +++++++++++++++- homeassistant/helpers/trigger.py | 113 +-------------- tests/helpers/test_target.py | 214 +++++++++++++++++++++++++++- tests/helpers/test_trigger.py | 236 +------------------------------ 4 files changed, 337 insertions(+), 349 deletions(-) diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index c16819235b9baf..1c0fd0745ca479 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -2,9 +2,11 @@ from __future__ import annotations +from collections.abc import Callable import dataclasses +import logging from logging import Logger -from typing import TypeGuard +from typing import Any, TypeGuard from homeassistant.const import ( ATTR_AREA_ID, @@ -14,7 +16,14 @@ ATTR_LABEL_ID, ENTITY_MATCH_NONE, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + EventStateChangedData, + HassJobType, + HomeAssistant, + callback, +) from . import ( area_registry as ar, @@ -25,8 +34,11 @@ group, label_registry as lr, ) +from .event import async_track_state_change_event from .typing import ConfigType +_LOGGER = logging.getLogger(__name__) + def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]: """Check if ids can match anything.""" @@ -238,3 +250,110 @@ def async_extract_referenced_entity_ids( ) return selected + + +class TargetStateChangeTracker: + """Helper class to manage state change tracking for targets.""" + + def __init__( + self, + hass: HomeAssistant, + selector_data: TargetSelectorData, + job_type: HassJobType | None, + action: Callable[[Event[EventStateChangedData]], Any], + ) -> None: + """Initialize the state change tracker.""" + self._hass = hass + self._selector_data = selector_data + self._job_type = job_type + self._action = action + + self._state_change_unsub: CALLBACK_TYPE | None = None + self._registry_unsubs: list[CALLBACK_TYPE] = [] + + self._setup_registry_listeners() + self._track_entities_state_change() + + def _track_entities_state_change(self) -> None: + """Set up state change tracking for currently selected entities.""" + selected = async_extract_referenced_entity_ids( + self._hass, self._selector_data, expand_group=False + ) + + @callback + def state_change_listener(event: Event[EventStateChangedData]) -> None: + """Handle state change events.""" + if ( + event.data["entity_id"] in selected.referenced + or event.data["entity_id"] in selected.indirectly_referenced + ): + self._action(event) + + tracked_entities = selected.referenced.union(selected.indirectly_referenced) + + _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) + self._state_change_unsub = async_track_state_change_event( + self._hass, tracked_entities, state_change_listener, job_type=self._job_type + ) + + def _setup_registry_listeners(self) -> None: + """Set up listeners for registry changes that require resubscription.""" + + @callback + def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: + """Resubscribe to state change events when registry changes.""" + if self._state_change_unsub: + self._state_change_unsub() + self._track_entities_state_change() + + # Subscribe to registry updates that can change the entities to track: + # - Entity registry: entity added/removed; entity labels changed; entity area changed. + # - Device registry: device labels changed; device area changed. + # - Area registry: area floor changed. + # + # We don't track other registries (like floor or label registries) because their + # changes don't affect which entities are tracked. + self._registry_unsubs = [ + self._hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, + resubscribe_state_change_event, + # TODO(abmantis): filter for entities that match the target selector? + # event_filter=self._filter_entity_registry_changes, + ), + self._hass.bus.async_listen( + dr.EVENT_DEVICE_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + self._hass.bus.async_listen( + ar.EVENT_AREA_REGISTRY_UPDATED, + resubscribe_state_change_event, + ), + ] + + def unsub(self) -> None: + """Unsubscribe from all events.""" + for registry_unsub in self._registry_unsubs: + registry_unsub() + self._registry_unsubs.clear() + if self._state_change_unsub: + self._state_change_unsub() + self._state_change_unsub = None + + +def async_track_target_selector_state_change_event( + hass: HomeAssistant, + target_selector_config: ConfigType, + action: Callable[[Event[EventStateChangedData]], Any], + job_type: HassJobType | None = None, +) -> CALLBACK_TYPE: + """Track state changes for entities referenced directly or indirectly in a target selector.""" + selector_data = TargetSelectorData(target_selector_config) + if not selector_data.has_any_selector: + _LOGGER.warning( + "Target selector %s does not have any selectors defined", + target_selector_config, + ) + return lambda: None + + tracker = TargetStateChangeTracker(hass, selector_data, job_type, action) + return tracker.unsub diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 6f64c60992a48e..57ee6b99029463 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -23,9 +23,7 @@ from homeassistant.core import ( CALLBACK_TYPE, Context, - Event, HassJob, - HassJobType, HomeAssistant, callback, is_callback, @@ -42,10 +40,8 @@ from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml.loader import JSON_TYPE -from . import area_registry, config_validation as cv, device_registry, entity_registry -from .event import EventStateChangedData, async_track_state_change_event +from . import config_validation as cv from .integration_platform import async_process_integration_platforms -from .target import TargetSelectorData, async_extract_referenced_entity_ids from .template import Template from .typing import ConfigType, TemplateVarsType @@ -625,110 +621,3 @@ async def async_get_all_descriptions( hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache return new_descriptions_cache - - -class TargetStateChangeTracker: - """Helper class to manage state change tracking for targets.""" - - def __init__( - self, - hass: HomeAssistant, - selector_data: TargetSelectorData, - job_type: HassJobType | None, - action: Callable[[Event[EventStateChangedData]], Any], - ) -> None: - """Initialize the state change tracker.""" - self._hass = hass - self._selector_data = selector_data - self._job_type = job_type - self._action = action - - self._state_change_unsub: CALLBACK_TYPE | None = None - self._registry_unsubs: list[CALLBACK_TYPE] = [] - - self._setup_registry_listeners() - self._track_entities_state_change() - - def _track_entities_state_change(self) -> None: - """Set up state change tracking for currently selected entities.""" - selected = async_extract_referenced_entity_ids( - self._hass, self._selector_data, expand_group=False - ) - - @callback - def state_change_listener(event: Event[EventStateChangedData]) -> None: - """Handle state change events.""" - if ( - event.data["entity_id"] in selected.referenced - or event.data["entity_id"] in selected.indirectly_referenced - ): - self._action(event) - - tracked_entities = selected.referenced.union(selected.indirectly_referenced) - - _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) - self._state_change_unsub = async_track_state_change_event( - self._hass, tracked_entities, state_change_listener, job_type=self._job_type - ) - - def _setup_registry_listeners(self) -> None: - """Set up listeners for registry changes that require resubscription.""" - - @callback - def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: - """Resubscribe to state change events when registry changes.""" - if self._state_change_unsub: - self._state_change_unsub() - self._track_entities_state_change() - - # Subscribe to registry updates that can change the entities to track: - # - Entity registry: entity added/removed; entity labels changed; entity area changed. - # - Device registry: device labels changed; device area changed. - # - Area registry: area floor changed. - # - # We don't track other registries (like floor or label registries) because their - # changes don't affect which entities are tracked. - self._registry_unsubs = [ - self._hass.bus.async_listen( - entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, - resubscribe_state_change_event, - # TODO(abmantis): filter for entities that match the target selector? - # event_filter=self._filter_entity_registry_changes, - ), - self._hass.bus.async_listen( - device_registry.EVENT_DEVICE_REGISTRY_UPDATED, - resubscribe_state_change_event, - ), - self._hass.bus.async_listen( - area_registry.EVENT_AREA_REGISTRY_UPDATED, - resubscribe_state_change_event, - ), - ] - - def unsub(self) -> None: - """Unsubscribe from all events.""" - for registry_unsub in self._registry_unsubs: - registry_unsub() - self._registry_unsubs.clear() - if self._state_change_unsub: - self._state_change_unsub() - self._state_change_unsub = None - - -def async_track_target_selector_state_change_event( - hass: HomeAssistant, - target_selector_config: ConfigType, - action: Callable[[Event[EventStateChangedData]], Any], - job_type: HassJobType | None = None, -) -> CALLBACK_TYPE: - """Track state changes for entities referenced directly or indirectly (by device, area, label, etc) in a target selector.""" - selector_data = TargetSelectorData(target_selector_config) - if not selector_data.has_any_selector: - _LOGGER.warning( - "Target selector %s does not have any selectors defined", - target_selector_config, - ) - return lambda: None - - tracker = TargetStateChangeTracker(hass, selector_data, job_type, action) - return tracker.unsub diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index ca38f316d89ae4..69f022bc9cd00b 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -17,17 +17,20 @@ STATE_ON, EntityCategory, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import ( area_registry as ar, device_registry as dr, entity_registry as er, + floor_registry as fr, + label_registry as lr, target, ) from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component from tests.common import ( + MockConfigEntry, RegistryEntryWithDefaults, mock_area_registry, mock_device_registry, @@ -457,3 +460,212 @@ async def test_extract_referenced_entity_ids( ) == expected_selected ) + + +async def test_async_track_target_selector_state_change_event_empty_selector( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test async_track_target_selector_state_change_event with empty selector.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + unsub = target.async_track_target_selector_state_change_event( + hass, {}, state_change_callback + ) + + assert "Target selector {} does not have any selectors defined" in caplog.text + + # Test that no state changes are tracked + hass.states.async_set("light.test", "on") + await hass.async_block_till_done() + + assert len(calls) == 0 + + unsub() + + +async def test_async_track_target_selector_state_change_event( + hass: HomeAssistant, +) -> None: + """Test async_track_target_selector_state_change_event with multiple targets.""" + calls = [] + + @callback + def state_change_callback(event): + """Handle state change events.""" + calls.append(event) + + # List of entities to toggle state during the test. This list should be insert-only + # so that all entities are changed every time. + entities_to_set_state = [] + # List of entities that should assert a state change when toggled. Contrary to + # entities_to_set_state, entities should be added and removed. + entities_to_assert_change = [] + last_state = STATE_OFF + + async def toggle_states(): + """Toggle the state of all the entities in test.""" + nonlocal last_state + last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF + for entity_id in entities_to_set_state: + hass.states.async_set(entity_id, last_state) + await hass.async_block_till_done() + + def assert_entity_calls_and_reset() -> None: + assert len(calls) == len(entities_to_assert_change) + for change_call in calls: + assert change_call.data["entity_id"] in entities_to_assert_change + assert change_call.data["new_state"].state == last_state + calls.clear() + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_reg = dr.async_get(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "device_1")}, + ) + + untargeted_device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("test", "area_device")}, + ) + + entity_reg = er.async_get(hass) + device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light", + device_id=device_entry.id, + ).entity_id + + untargeted_device_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="area_device_light", + device_id=untargeted_device_entry.id, + ).entity_id + + untargeted_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="untargeted_light", + ).entity_id + + targeted_entity = "light.test_light" + + entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity]) + await toggle_states() + + label = lr.async_get(hass).async_create("Test Label").name + area = ar.async_get(hass).async_create("Test Area").id + floor = fr.async_get(hass).async_create("Test Floor").floor_id + + selector_config = { + ATTR_ENTITY_ID: targeted_entity, + ATTR_DEVICE_ID: device_entry.id, + ATTR_AREA_ID: area, + ATTR_FLOOR_ID: floor, + ATTR_LABEL_ID: label, + } + unsub = target.async_track_target_selector_state_change_event( + hass, selector_config, state_change_callback + ) + + # Test directly targeted entity and device + entities_to_assert_change.extend([targeted_entity, device_entity]) + await toggle_states() + assert_entity_calls_and_reset() + + # Add new entity to the targeted device -> should trigger on state change + device_entity_2 = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="device_light_2", + device_id=device_entry.id, + ).entity_id + await hass.async_block_till_done() + + entities_to_set_state.append(device_entity_2) + entities_to_assert_change.append(device_entity_2) + await toggle_states() + assert_entity_calls_and_reset() + + # Test untargeted entity -> should not trigger + entities_to_set_state.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Add label to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, labels={label}) + await hass.async_block_till_done() + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Remove label from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, labels={}) + await hass.async_block_till_done() + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Add area to untargeted entity -> should trigger now + entity_reg.async_update_entity(untargeted_entity, area_id=area) + await hass.async_block_till_done() + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Remove area from untargeted entity -> should not trigger anymore + entity_reg.async_update_entity(untargeted_entity, area_id=None) + await hass.async_block_till_done() + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Add area to untargeted device -> should trigger on state change + device_reg.async_update_device(untargeted_device_entry.id, area_id=area) + await hass.async_block_till_done() + entities_to_set_state.append(untargeted_device_entity) + entities_to_assert_change.append(untargeted_device_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Remove area from untargeted device -> should not trigger anymore + device_reg.async_update_device(untargeted_device_entry.id, area_id=None) + await hass.async_block_till_done() + entities_to_assert_change.remove(untargeted_device_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Set the untargeted area on the untargeted entity -> should not trigger + untracked_area = ar.async_get(hass).async_create("Untargeted Area").id + entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) + await hass.async_block_till_done() + await toggle_states() + assert_entity_calls_and_reset() + + # Set targeted floor on the untargeted area -> should trigger now + ar.async_get(hass).async_update(untracked_area, floor_id=floor) + await hass.async_block_till_done() + entities_to_assert_change.append(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # Remove untargeted area from targeted floor -> should not trigger anymore + ar.async_get(hass).async_update(untracked_area, floor_id=None) + await hass.async_block_till_done() + entities_to_assert_change.remove(untargeted_entity) + await toggle_states() + assert_entity_calls_and_reset() + + # After unsubscribing, changes should not trigger + unsub() + await toggle_states() + assert len(calls) == 0 diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 4dd4da94835fc0..f2f81e8809c17e 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -10,15 +10,6 @@ from homeassistant.components.sun import DOMAIN as DOMAIN_SUN from homeassistant.components.system_health import DOMAIN as DOMAIN_SYSTEM_HEALTH from homeassistant.components.tag import DOMAIN as DOMAIN_TAG -from homeassistant.const import ( - ATTR_AREA_ID, - ATTR_DEVICE_ID, - ATTR_ENTITY_ID, - ATTR_FLOOR_ID, - ATTR_LABEL_ID, - STATE_OFF, - STATE_ON, -) from homeassistant.core import ( CALLBACK_TYPE, Context, @@ -27,14 +18,7 @@ callback, ) from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - floor_registry as fr, - label_registry as lr, - trigger, -) +from homeassistant.helpers import trigger from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, @@ -43,7 +27,6 @@ TriggerInfo, _async_get_trigger_platform, async_initialize_triggers, - async_track_target_selector_state_change_event, async_validate_trigger_config, ) from homeassistant.helpers.typing import ConfigType @@ -51,13 +34,7 @@ from homeassistant.setup import async_setup_component from homeassistant.util.yaml.loader import parse_yaml -from tests.common import ( - MockConfigEntry, - MockModule, - MockPlatform, - mock_integration, - mock_platform, -) +from tests.common import MockModule, MockPlatform, mock_integration, mock_platform async def test_bad_trigger_platform(hass: HomeAssistant) -> None: @@ -803,212 +780,3 @@ async def test_invalid_trigger_platform( await async_setup_component(hass, "test", {}) assert "Integration test does not provide trigger support, skipping" in caplog.text - - -async def test_async_track_target_selector_state_change_event_empty_selector( - hass: HomeAssistant, caplog: pytest.LogCaptureFixture -) -> None: - """Test async_track_target_selector_state_change_event with empty selector.""" - calls = [] - - @callback - def state_change_callback(event): - """Handle state change events.""" - calls.append(event) - - unsub = async_track_target_selector_state_change_event( - hass, {}, state_change_callback - ) - - assert "Target selector {} does not have any selectors defined" in caplog.text - - # Test that no state changes are tracked - hass.states.async_set("light.test", "on") - await hass.async_block_till_done() - - assert len(calls) == 0 - - unsub() - - -async def test_async_track_target_selector_state_change_event( - hass: HomeAssistant, -) -> None: - """Test async_track_target_selector_state_change_event with multiple targets.""" - calls = [] - - @callback - def state_change_callback(event): - """Handle state change events.""" - calls.append(event) - - # List of entities to toggle state during the test. This list should be insert-only - # so that all entities are changed every time. - entities_to_set_state = [] - # List of entities that should assert a state change when toggled. Contrary to - # entities_to_set_state, entities should be added and removed. - entities_to_assert_change = [] - last_state = STATE_OFF - - async def toggle_states(): - """Toggle the state of all the entities in test.""" - nonlocal last_state - last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF - for entity_id in entities_to_set_state: - hass.states.async_set(entity_id, last_state) - await hass.async_block_till_done() - - def assert_entity_calls_and_reset() -> None: - assert len(calls) == len(entities_to_assert_change) - for change_call in calls: - assert change_call.data["entity_id"] in entities_to_assert_change - assert change_call.data["new_state"].state == last_state - calls.clear() - - config_entry = MockConfigEntry(domain="test") - config_entry.add_to_hass(hass) - - device_reg = dr.async_get(hass) - device_entry = device_reg.async_get_or_create( - config_entry_id=config_entry.entry_id, - identifiers={("test", "device_1")}, - ) - - untargeted_device_entry = device_reg.async_get_or_create( - config_entry_id=config_entry.entry_id, - identifiers={("test", "area_device")}, - ) - - entity_reg = er.async_get(hass) - device_entity = entity_reg.async_get_or_create( - domain="light", - platform="test", - unique_id="device_light", - device_id=device_entry.id, - ).entity_id - - untargeted_device_entity = entity_reg.async_get_or_create( - domain="light", - platform="test", - unique_id="area_device_light", - device_id=untargeted_device_entry.id, - ).entity_id - - untargeted_entity = entity_reg.async_get_or_create( - domain="light", - platform="test", - unique_id="untargeted_light", - ).entity_id - - targeted_entity = "light.test_light" - - entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity]) - await toggle_states() - - label = lr.async_get(hass).async_create("Test Label").name - area = ar.async_get(hass).async_create("Test Area").id - floor = fr.async_get(hass).async_create("Test Floor").floor_id - - selector_config = { - ATTR_ENTITY_ID: targeted_entity, - ATTR_DEVICE_ID: device_entry.id, - ATTR_AREA_ID: area, - ATTR_FLOOR_ID: floor, - ATTR_LABEL_ID: label, - } - unsub = async_track_target_selector_state_change_event( - hass, selector_config, state_change_callback - ) - - # Test directly targeted entity and device - entities_to_assert_change.extend([targeted_entity, device_entity]) - await toggle_states() - assert_entity_calls_and_reset() - - # Add new entity to the targeted device -> should trigger on state change - device_entity_2 = entity_reg.async_get_or_create( - domain="light", - platform="test", - unique_id="device_light_2", - device_id=device_entry.id, - ).entity_id - await hass.async_block_till_done() - - entities_to_set_state.append(device_entity_2) - entities_to_assert_change.append(device_entity_2) - await toggle_states() - assert_entity_calls_and_reset() - - # Test untargeted entity -> should not trigger - entities_to_set_state.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Add label to untargeted entity -> should trigger now - entity_reg.async_update_entity(untargeted_entity, labels={label}) - await hass.async_block_till_done() - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Remove label from untargeted entity -> should not trigger anymore - entity_reg.async_update_entity(untargeted_entity, labels={}) - await hass.async_block_till_done() - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Add area to untargeted entity -> should trigger now - entity_reg.async_update_entity(untargeted_entity, area_id=area) - await hass.async_block_till_done() - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Remove area from untargeted entity -> should not trigger anymore - entity_reg.async_update_entity(untargeted_entity, area_id=None) - await hass.async_block_till_done() - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Add area to untargeted device -> should trigger on state change - device_reg.async_update_device(untargeted_device_entry.id, area_id=area) - await hass.async_block_till_done() - entities_to_set_state.append(untargeted_device_entity) - entities_to_assert_change.append(untargeted_device_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Remove area from untargeted device -> should not trigger anymore - device_reg.async_update_device(untargeted_device_entry.id, area_id=None) - await hass.async_block_till_done() - entities_to_assert_change.remove(untargeted_device_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Set the untargeted area on the untargeted entity -> should not trigger - untracked_area = ar.async_get(hass).async_create("Untargeted Area").id - entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) - await hass.async_block_till_done() - await toggle_states() - assert_entity_calls_and_reset() - - # Set targeted floor on the untargeted area -> should trigger now - ar.async_get(hass).async_update(untracked_area, floor_id=floor) - await hass.async_block_till_done() - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # Remove untargeted area from targeted floor -> should not trigger anymore - ar.async_get(hass).async_update(untracked_area, floor_id=None) - await hass.async_block_till_done() - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() - - # After unsubscribing, changes should not trigger - unsub() - await toggle_states() - assert len(calls) == 0 From b6163b8c47210b2e89042e6a9aca15de83391983 Mon Sep 17 00:00:00 2001 From: abmantis Date: Tue, 8 Jul 2025 14:39:08 +0100 Subject: [PATCH 11/19] Cleanup --- homeassistant/helpers/target.py | 11 +++-------- tests/helpers/test_target.py | 3 --- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index 1c0fd0745ca479..e0f75c55c7e8ff 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -315,18 +315,13 @@ def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: # changes don't affect which entities are tracked. self._registry_unsubs = [ self._hass.bus.async_listen( - er.EVENT_ENTITY_REGISTRY_UPDATED, - resubscribe_state_change_event, - # TODO(abmantis): filter for entities that match the target selector? - # event_filter=self._filter_entity_registry_changes, + er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event ), self._hass.bus.async_listen( - dr.EVENT_DEVICE_REGISTRY_UPDATED, - resubscribe_state_change_event, + dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event ), self._hass.bus.async_listen( - ar.EVENT_AREA_REGISTRY_UPDATED, - resubscribe_state_change_event, + ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event ), ] diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index 69f022bc9cd00b..2d3a26e2cd856a 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -2,9 +2,6 @@ import pytest -# TODO(abmantis): is this import needed? -# To prevent circular import when running just this file -import homeassistant.components # noqa: F401 from homeassistant.components.group import Group from homeassistant.const import ( ATTR_AREA_ID, From 62632751a7c6e17019b2ca626f608e84b6089407 Mon Sep 17 00:00:00 2001 From: abmantis Date: Tue, 8 Jul 2025 15:01:51 +0100 Subject: [PATCH 12/19] Raise instead of logging --- homeassistant/helpers/target.py | 8 +++----- tests/helpers/test_target.py | 19 +++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index e0f75c55c7e8ff..513d6b70d19401 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -24,6 +24,7 @@ HomeAssistant, callback, ) +from homeassistant.exceptions import HomeAssistantError from . import ( area_registry as ar, @@ -344,11 +345,8 @@ def async_track_target_selector_state_change_event( """Track state changes for entities referenced directly or indirectly in a target selector.""" selector_data = TargetSelectorData(target_selector_config) if not selector_data.has_any_selector: - _LOGGER.warning( - "Target selector %s does not have any selectors defined", - target_selector_config, + raise HomeAssistantError( + f"Target selector {target_selector_config} does not have any selectors defined" ) - return lambda: None - tracker = TargetStateChangeTracker(hass, selector_data, job_type, action) return tracker.unsub diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index 2d3a26e2cd856a..a5b75cf0615266 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -15,6 +15,7 @@ EntityCategory, ) from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, device_registry as dr, @@ -470,20 +471,14 @@ def state_change_callback(event): """Handle state change events.""" calls.append(event) - unsub = target.async_track_target_selector_state_change_event( - hass, {}, state_change_callback + with pytest.raises(HomeAssistantError) as excinfo: + target.async_track_target_selector_state_change_event( + hass, {}, state_change_callback + ) + assert str(excinfo.value) == ( + "Target selector {} does not have any selectors defined" ) - assert "Target selector {} does not have any selectors defined" in caplog.text - - # Test that no state changes are tracked - hass.states.async_set("light.test", "on") - await hass.async_block_till_done() - - assert len(calls) == 0 - - unsub() - async def test_async_track_target_selector_state_change_event( hass: HomeAssistant, From 28406144d30a5e22033c9439daf9774165ff7f27 Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 11:35:58 +0100 Subject: [PATCH 13/19] Revert change to test_trigger.py --- tests/helpers/test_trigger.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index f2f81e8809c17e..ba9db9cb053646 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -727,6 +727,19 @@ def _load_yaml(fname, secrets=None): ) in caplog.text +async def test_invalid_trigger_platform( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test invalid trigger platform.""" + mock_integration(hass, MockModule("test", async_setup=AsyncMock(return_value=True))) + mock_platform(hass, "test.trigger", MockPlatform()) + + await async_setup_component(hass, "test", {}) + + assert "Integration test does not provide trigger support, skipping" in caplog.text + + @patch("annotatedyaml.loader.load_yaml") @patch.object(Integration, "has_triggers", return_value=True) async def test_subscribe_triggers( @@ -767,16 +780,3 @@ async def good_subscriber(new_triggers: set[str]): assert trigger_events == [{"sun"}] assert "Error while notifying trigger platform listener" in caplog.text - - -async def test_invalid_trigger_platform( - hass: HomeAssistant, - caplog: pytest.LogCaptureFixture, -) -> None: - """Test invalid trigger platform.""" - mock_integration(hass, MockModule("test", async_setup=AsyncMock(return_value=True))) - mock_platform(hass, "test.trigger", MockPlatform()) - - await async_setup_component(hass, "test", {}) - - assert "Integration test does not provide trigger support, skipping" in caplog.text From 19e9ce3a1b32e26a367b420e4825183acc6716f9 Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 11:52:38 +0100 Subject: [PATCH 14/19] Remove job_type --- homeassistant/helpers/target.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index 513d6b70d19401..d95a9ef108a8fb 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -20,7 +20,6 @@ CALLBACK_TYPE, Event, EventStateChangedData, - HassJobType, HomeAssistant, callback, ) @@ -260,13 +259,11 @@ def __init__( self, hass: HomeAssistant, selector_data: TargetSelectorData, - job_type: HassJobType | None, action: Callable[[Event[EventStateChangedData]], Any], ) -> None: """Initialize the state change tracker.""" self._hass = hass self._selector_data = selector_data - self._job_type = job_type self._action = action self._state_change_unsub: CALLBACK_TYPE | None = None @@ -294,7 +291,7 @@ def state_change_listener(event: Event[EventStateChangedData]) -> None: _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities) self._state_change_unsub = async_track_state_change_event( - self._hass, tracked_entities, state_change_listener, job_type=self._job_type + self._hass, tracked_entities, state_change_listener ) def _setup_registry_listeners(self) -> None: @@ -340,7 +337,6 @@ def async_track_target_selector_state_change_event( hass: HomeAssistant, target_selector_config: ConfigType, action: Callable[[Event[EventStateChangedData]], Any], - job_type: HassJobType | None = None, ) -> CALLBACK_TYPE: """Track state changes for entities referenced directly or indirectly in a target selector.""" selector_data = TargetSelectorData(target_selector_config) @@ -348,5 +344,5 @@ def async_track_target_selector_state_change_event( raise HomeAssistantError( f"Target selector {target_selector_config} does not have any selectors defined" ) - tracker = TargetStateChangeTracker(hass, selector_data, job_type, action) + tracker = TargetStateChangeTracker(hass, selector_data, action) return tracker.unsub From 90c525b222152d1a82dac917f4e139826e8309dd Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 11:59:14 +0100 Subject: [PATCH 15/19] Move setup to async_setup --- homeassistant/helpers/target.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index d95a9ef108a8fb..239d1e66336273 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -269,8 +269,11 @@ def __init__( self._state_change_unsub: CALLBACK_TYPE | None = None self._registry_unsubs: list[CALLBACK_TYPE] = [] + def async_setup(self) -> Callable[[], None]: + """Set up the state change tracking.""" self._setup_registry_listeners() self._track_entities_state_change() + return self._unsubscribe def _track_entities_state_change(self) -> None: """Set up state change tracking for currently selected entities.""" @@ -323,7 +326,7 @@ def resubscribe_state_change_event(event: Event[Any] | None = None) -> None: ), ] - def unsub(self) -> None: + def _unsubscribe(self) -> None: """Unsubscribe from all events.""" for registry_unsub in self._registry_unsubs: registry_unsub() @@ -345,4 +348,4 @@ def async_track_target_selector_state_change_event( f"Target selector {target_selector_config} does not have any selectors defined" ) tracker = TargetStateChangeTracker(hass, selector_data, action) - return tracker.unsub + return tracker.async_setup() From a9d8f06db1453ee3f2d6a31ae34d71c381151071 Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 12:57:20 +0100 Subject: [PATCH 16/19] Remove async_block_till_done --- tests/helpers/test_target.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index a5b75cf0615266..d23a5a0651e182 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -551,7 +551,9 @@ def assert_entity_calls_and_reset() -> None: targeted_entity = "light.test_light" - entities_to_set_state.extend([targeted_entity, device_entity, untargeted_entity]) + entities_to_set_state.extend( + [targeted_entity, device_entity, untargeted_entity, untargeted_device_entity] + ) await toggle_states() label = lr.async_get(hass).async_create("Test Label").name @@ -581,7 +583,6 @@ def assert_entity_calls_and_reset() -> None: unique_id="device_light_2", device_id=device_entry.id, ).entity_id - await hass.async_block_till_done() entities_to_set_state.append(device_entity_2) entities_to_assert_change.append(device_entity_2) @@ -589,49 +590,41 @@ def assert_entity_calls_and_reset() -> None: assert_entity_calls_and_reset() # Test untargeted entity -> should not trigger - entities_to_set_state.append(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Add label to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, labels={label}) - await hass.async_block_till_done() entities_to_assert_change.append(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Remove label from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, labels={}) - await hass.async_block_till_done() entities_to_assert_change.remove(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Add area to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, area_id=area) - await hass.async_block_till_done() entities_to_assert_change.append(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Remove area from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, area_id=None) - await hass.async_block_till_done() entities_to_assert_change.remove(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Add area to untargeted device -> should trigger on state change device_reg.async_update_device(untargeted_device_entry.id, area_id=area) - await hass.async_block_till_done() - entities_to_set_state.append(untargeted_device_entity) entities_to_assert_change.append(untargeted_device_entity) await toggle_states() assert_entity_calls_and_reset() # Remove area from untargeted device -> should not trigger anymore device_reg.async_update_device(untargeted_device_entry.id, area_id=None) - await hass.async_block_till_done() entities_to_assert_change.remove(untargeted_device_entity) await toggle_states() assert_entity_calls_and_reset() @@ -639,20 +632,17 @@ def assert_entity_calls_and_reset() -> None: # Set the untargeted area on the untargeted entity -> should not trigger untracked_area = ar.async_get(hass).async_create("Untargeted Area").id entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) - await hass.async_block_till_done() await toggle_states() assert_entity_calls_and_reset() # Set targeted floor on the untargeted area -> should trigger now ar.async_get(hass).async_update(untracked_area, floor_id=floor) - await hass.async_block_till_done() entities_to_assert_change.append(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() # Remove untargeted area from targeted floor -> should not trigger anymore ar.async_get(hass).async_update(untracked_area, floor_id=None) - await hass.async_block_till_done() entities_to_assert_change.remove(untargeted_entity) await toggle_states() assert_entity_calls_and_reset() From 571905a96828eda13dc2fbf1ecf9596aaed9c31e Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 15:18:07 +0100 Subject: [PATCH 17/19] Address review suggestions --- tests/helpers/test_target.py | 114 +++++++++++++++++------------------ 1 file changed, 55 insertions(+), 59 deletions(-) diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index d23a5a0651e182..9e3cf3c79dcb14 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -14,7 +14,7 @@ STATE_ON, EntityCategory, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, @@ -484,35 +484,32 @@ async def test_async_track_target_selector_state_change_event( hass: HomeAssistant, ) -> None: """Test async_track_target_selector_state_change_event with multiple targets.""" - calls = [] + events: list[Event[EventStateChangedData]] = [] @callback - def state_change_callback(event): + def state_change_callback(event: Event[EventStateChangedData]): """Handle state change events.""" - calls.append(event) + events.append(event) - # List of entities to toggle state during the test. This list should be insert-only - # so that all entities are changed every time. - entities_to_set_state = [] - # List of entities that should assert a state change when toggled. Contrary to - # entities_to_set_state, entities should be added and removed. - entities_to_assert_change = [] last_state = STATE_OFF - async def toggle_states(): - """Toggle the state of all the entities in test.""" + async def set_states_and_check_events( + entities_to_set_state: list[str], entities_to_assert_change: list[str] + ) -> None: + """Toggle the state entities and check for events.""" nonlocal last_state last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF for entity_id in entities_to_set_state: hass.states.async_set(entity_id, last_state) await hass.async_block_till_done() - def assert_entity_calls_and_reset() -> None: - assert len(calls) == len(entities_to_assert_change) - for change_call in calls: - assert change_call.data["entity_id"] in entities_to_assert_change - assert change_call.data["new_state"].state == last_state - calls.clear() + assert len(events) == len(entities_to_assert_change) + entities_seen = set() + for event in events: + entities_seen.add(event.data["entity_id"]) + assert event.data["new_state"].state == last_state + assert entities_seen == set(entities_to_assert_change) + events.clear() config_entry = MockConfigEntry(domain="test") config_entry.add_to_hass(hass) @@ -551,10 +548,10 @@ def assert_entity_calls_and_reset() -> None: targeted_entity = "light.test_light" - entities_to_set_state.extend( - [targeted_entity, device_entity, untargeted_entity, untargeted_device_entity] - ) - await toggle_states() + # List of entities to toggle state during the test. This list should be insert-only + # so that all entities are changed every time. + targeted_entities = [targeted_entity, device_entity] + await set_states_and_check_events(targeted_entities, []) label = lr.async_get(hass).async_create("Test Label").name area = ar.async_get(hass).async_create("Test Area").id @@ -572,9 +569,7 @@ def assert_entity_calls_and_reset() -> None: ) # Test directly targeted entity and device - entities_to_assert_change.extend([targeted_entity, device_entity]) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events(targeted_entities, targeted_entities) # Add new entity to the targeted device -> should trigger on state change device_entity_2 = entity_reg.async_get_or_create( @@ -584,70 +579,71 @@ def assert_entity_calls_and_reset() -> None: device_id=device_entry.id, ).entity_id - entities_to_set_state.append(device_entity_2) - entities_to_assert_change.append(device_entity_2) - await toggle_states() - assert_entity_calls_and_reset() + targeted_entities = [targeted_entity, device_entity, device_entity_2] + await set_states_and_check_events(targeted_entities, targeted_entities) # Test untargeted entity -> should not trigger - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) # Add label to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, labels={label}) - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity] + ) # Remove label from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, labels={}) - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) # Add area to untargeted entity -> should trigger now entity_reg.async_update_entity(untargeted_entity, area_id=area) - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity] + ) # Remove area from untargeted entity -> should not trigger anymore entity_reg.async_update_entity(untargeted_entity, area_id=None) - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) # Add area to untargeted device -> should trigger on state change device_reg.async_update_device(untargeted_device_entry.id, area_id=area) - entities_to_assert_change.append(untargeted_device_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_device_entity], + [*targeted_entities, untargeted_device_entity], + ) # Remove area from untargeted device -> should not trigger anymore device_reg.async_update_device(untargeted_device_entry.id, area_id=None) - entities_to_assert_change.remove(untargeted_device_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_device_entity], targeted_entities + ) # Set the untargeted area on the untargeted entity -> should not trigger untracked_area = ar.async_get(hass).async_create("Untargeted Area").id entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) # Set targeted floor on the untargeted area -> should trigger now ar.async_get(hass).async_update(untracked_area, floor_id=floor) - entities_to_assert_change.append(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], + [*targeted_entities, untargeted_entity], + ) # Remove untargeted area from targeted floor -> should not trigger anymore ar.async_get(hass).async_update(untracked_area, floor_id=None) - entities_to_assert_change.remove(untargeted_entity) - await toggle_states() - assert_entity_calls_and_reset() + await set_states_and_check_events( + [*targeted_entities, untargeted_entity], targeted_entities + ) # After unsubscribing, changes should not trigger unsub() - await toggle_states() - assert len(calls) == 0 + await set_states_and_check_events(targeted_entities, []) From 9c4c4607d2f578df0319f929e3ea719bc38616f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ab=C3=ADlio=20Costa?= Date: Mon, 14 Jul 2025 16:06:46 +0100 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: Erik Montnemery --- tests/helpers/test_target.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index 9e3cf3c79dcb14..12765d82c47915 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -464,12 +464,10 @@ async def test_async_track_target_selector_state_change_event_empty_selector( hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test async_track_target_selector_state_change_event with empty selector.""" - calls = [] @callback def state_change_callback(event): """Handle state change events.""" - calls.append(event) with pytest.raises(HomeAssistantError) as excinfo: target.async_track_target_selector_state_change_event( From 058928eacde9d7c28a02877fd11f1caf9ff3437e Mon Sep 17 00:00:00 2001 From: abmantis Date: Mon, 14 Jul 2025 16:07:29 +0100 Subject: [PATCH 19/19] Remove stale comment --- tests/helpers/test_target.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index 12765d82c47915..c87a320e3789bf 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -546,8 +546,6 @@ async def set_states_and_check_events( targeted_entity = "light.test_light" - # List of entities to toggle state during the test. This list should be insert-only - # so that all entities are changed every time. targeted_entities = [targeted_entity, device_entity] await set_states_and_check_events(targeted_entities, [])