Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 123 additions & 5 deletions homeassistant/helpers/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@
NumericThresholdType,
TargetSelector,
)
from .target import TargetSelection, async_extract_referenced_entity_ids
from .target import (
TargetSelection,
TargetStateChangedData,
async_extract_referenced_entity_ids,
async_track_target_selector_state_change_event,
)
from .template import Template, render_complex
from .trace import (
TraceElement,
Expand Down Expand Up @@ -458,18 +463,108 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
if TYPE_CHECKING:
assert config.target
assert config.options
self._target = config.target
self._target_selection = TargetSelection(config.target)
self._behavior = config.options[ATTR_BEHAVIOR]
self._duration: timedelta | None = config.options.get(CONF_FOR)
if self._behavior == BEHAVIOR_ANY:
self._matcher = self._check_any_match_state
elif self._behavior == BEHAVIOR_ALL:
self._matcher = self._check_all_match_state
self._on_unload: list[Callable[[], None]] = []
self._valid_since: dict[str, datetime] = {}

def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)

@property
def _needs_duration_tracking(self) -> bool:
"""Whether this condition needs active state change tracking for duration.

The base implementation intentionally defaults to always tracking
duration and should be overridden by subclasses that can safely use
state.last_changed directly. For example, conditions that are true
for a single main state value may not need active tracking, while
conditions that track attributes or match multiple states do because
last_changed does not capture those transitions.
"""
return True

def _update_valid_since(self, entity_id: str, _state: State | None) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we prefix state with underscore?

"""Update _valid_since tracking for an entity based on its current state.

If the entity is in a valid state and not already tracked, records when
the condition became true. If the entity is not in a valid state, removes
it from tracking.

For state-based conditions (value_source is None), last_changed
accurately reflects when the state changed to the current value.
For attribute-based conditions, last_changed only tracks main state
changes, so we use last_updated which is bumped on any update
(state or attributes). This is conservative — the tracked attribute
may have held its value longer — but it's the best we can do
to avoid false positives.
"""
if (
_state is not None
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
and self.is_valid_state(_state)
):
# Only record the time if not already tracked, to avoid
# resetting the duration on unrelated state/attribute updates.
if entity_id not in self._valid_since:
domain_spec = self._domain_specs[_state.domain]
if domain_spec.value_source is None:
self._valid_since[entity_id] = _state.last_changed
else:
self._valid_since[entity_id] = _state.last_updated
else:
self._valid_since.pop(entity_id, None)

@override
async def async_setup(self) -> None:
"""Set up state tracking for duration-based conditions."""
await super().async_setup()
if not self._duration or not self._needs_duration_tracking:
return

@callback
def _state_change_listener(
data: TargetStateChangedData,
) -> None:
"""Track when entities enter or leave a valid state."""
event = data.state_change_event
entity_id = event.data["entity_id"]
to_state = event.data["new_state"]

self._update_valid_since(entity_id, to_state)

@callback
def _on_entities_update(added: set[str], removed: set[str]) -> None:
"""Handle changes to the tracked entity set."""
for entity_id in added:
self._update_valid_since(entity_id, self._hass.states.get(entity_id))
for entity_id in removed:
self._valid_since.pop(entity_id, None)

unsub = async_track_target_selector_state_change_event(
self._hass,
self._target,
_state_change_listener,
self.entity_filter,
_on_entities_update,
)
self._on_unload.append(unsub)
Comment thread
arturpragacz marked this conversation as resolved.

@override
def async_unload(self) -> None:
Comment thread
emontnemery marked this conversation as resolved.
"""Unsubscribe from listeners."""
super().async_unload()
for cb in self._on_unload:
cb()
self._on_unload.clear()
Comment thread
arturpragacz marked this conversation as resolved.

def _get_tracked_value(self, entity_state: State) -> Any:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[entity_state.domain]
Expand All @@ -486,9 +581,16 @@ def _check_any_match_state(self, states: list[State]) -> bool:
if not self._duration:
# Skip duration check if duration is not specified or 0
return any(self.is_valid_state(state) for state in states)
duration = dt_util.utcnow() - self._duration
cutoff = dt_util.utcnow() - self._duration
if not self._needs_duration_tracking:
return any(
self.is_valid_state(state) and state.last_changed <= cutoff
for state in states
)
return any(
self.is_valid_state(state) and duration > state.last_changed
self.is_valid_state(state)
and (valid_since := self._valid_since.get(state.entity_id)) is not None
and valid_since <= cutoff
for state in states
)

Expand All @@ -497,9 +599,16 @@ def _check_all_match_state(self, states: list[State]) -> bool:
if not self._duration:
# Skip duration check if duration is not specified or 0
return all(self.is_valid_state(state) for state in states)
duration = dt_util.utcnow() - self._duration
cutoff = dt_util.utcnow() - self._duration
if not self._needs_duration_tracking:
return all(
self.is_valid_state(state) and state.last_changed <= cutoff
for state in states
)
return all(
self.is_valid_state(state) and duration > state.last_changed
self.is_valid_state(state)
and (valid_since := self._valid_since.get(state.entity_id)) is not None
and valid_since <= cutoff
for state in states
)

Expand All @@ -526,6 +635,15 @@ class EntityStateConditionBase(EntityConditionBase):

_states: set[str | bool]

@property
def _needs_duration_tracking(self) -> bool:
"""Single-state conditions with no attribute tracking can use last_changed."""
if len(self._states) != 1:
return True
return any(
spec.value_source is not None for spec in self._domain_specs.values()
)

def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
return self._get_tracked_value(entity_state) in self._states
Expand Down
13 changes: 13 additions & 0 deletions homeassistant/helpers/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def __init__(
target_selection: TargetSelection,
action: Callable[[TargetStateChangedData], Any],
entity_filter: Callable[[set[str]], set[str]],
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
*,
primary_entities_only: bool = True,
) -> None:
Expand All @@ -368,10 +369,20 @@ def __init__(
primary_entities_only=primary_entities_only,
)
self._action = action
self._on_entities_update = on_entities_update
self._state_change_unsub: CALLBACK_TYPE | None = None
self._tracked_entities: set[str] = set()

def _handle_entities_update(self, tracked_entities: set[str]) -> None:
"""Handle the tracked entities."""
previous_entities = self._tracked_entities
self._tracked_entities = tracked_entities

if self._on_entities_update is not None:
added = tracked_entities - previous_entities
removed = previous_entities - tracked_entities
if added or removed:
self._on_entities_update(added, removed)

@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
Expand Down Expand Up @@ -399,6 +410,7 @@ def async_track_target_selector_state_change_event(
target_selector_config: ConfigType,
action: Callable[[TargetStateChangedData], Any],
entity_filter: Callable[[set[str]], set[str]] = lambda x: x,
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
*,
primary_entities_only: bool = True,
) -> CALLBACK_TYPE:
Expand All @@ -417,6 +429,7 @@ def async_track_target_selector_state_change_event(
target_selection,
action,
entity_filter,
on_entities_update,
primary_entities_only=primary_entities_only,
)
return tracker.async_setup()
Loading