Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __init__(
event_type: str,
data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None,
time_fired: Optional[datetime.datetime] = None,
context: Optional[Context] = None,
) -> None:
"""Initialize a new event."""
Expand Down
220 changes: 172 additions & 48 deletions homeassistant/helpers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ class TrackTemplate:

The template is template to calculate.
The variables are variables to pass to the template.
The rate_limit is a rate limit on how often the template is re-rendered.
"""

template: Template
variables: TemplateVarsType
rate_limit: Optional[timedelta] = None


@dataclass
Expand Down Expand Up @@ -560,11 +562,15 @@ def __init__(
self._listeners: Dict[str, Callable] = {}

self._last_result: Dict[Template, Union[str, TemplateError]] = {}
self._last_info: Dict[Template, RenderInfo] = {}
self._last_rendered: Dict[Template, datetime] = {}
self._rate_limit_timers: Dict[Template, asyncio.TimerHandle] = {}
self._info: Dict[Template, RenderInfo] = {}
self._last_domains: Set = set()
self._last_entities: Set = set()

self._scan_intervals: Dict[Template, timedelta] = {}
self._scan_interval_listeners: Dict[Template, Callable] = {}

def async_setup(self, raise_on_template_error: bool) -> None:
"""Activation of template tracking."""
for track_template_ in self._track_templates:
Expand All @@ -581,7 +587,6 @@ def async_setup(self, raise_on_template_error: bool) -> None:
exc_info=self._info[template].exception,
)

self._last_info = self._info.copy()
self._create_listeners()
_LOGGER.debug(
"Template group %s listens for %s",
Expand All @@ -596,32 +601,28 @@ def listeners(self) -> Dict:
"all": _TEMPLATE_ALL_LISTENER in self._listeners,
"entities": self._last_entities,
"domains": self._last_domains,
"scan_intervals": list(self._scan_intervals.values()),
}

@property
def _needs_all_listener(self) -> bool:
for track_template_ in self._track_templates:
template = track_template_.template

for info in self._info.values():
# Tracking all states
if (
self._info[template].all_states
or self._info[template].all_states_lifecycle
):
if info.all_states or info.all_states_lifecycle:
return True

# Previous call had an exception
# so we do not know which states
# to track
if self._info[template].exception:
if info.exception:
return True

return False

@property
def _all_templates_are_static(self) -> bool:
for track_template_ in self._track_templates:
if not self._info[track_template_.template].is_static:
for info in self._info.values():
if not info.is_static:
return False

return True
Expand All @@ -631,6 +632,10 @@ def _create_listeners(self) -> None:
if self._all_templates_are_static:
return

for template, info in self._info.items():
if info.scan_interval:
self._setup_scan_interval_listener(template, info.scan_interval)

if self._needs_all_listener:
self._setup_all_listener()
return
Expand All @@ -641,15 +646,39 @@ def _create_listeners(self) -> None:
self._setup_domains_listener(self._last_domains)
self._setup_entities_listener(self._last_domains, self._last_entities)

@callback
def _cancel_scan_interval_listener(self, template: Template) -> None:
if template not in self._scan_interval_listeners:
return

self._scan_intervals.pop(template)
self._scan_interval_listeners.pop(template)()

@callback
def _cancel_listener(self, listener_name: str) -> None:
if listener_name not in self._listeners:
return

self._listeners.pop(listener_name)()

@callback
def _cancel_rate_limit_timer(self, template: Template) -> None:
if template not in self._rate_limit_timers:
return

self._rate_limit_timers.pop(template).cancel()

@callback
def _update_listeners(self) -> None:

for template, info in self._info.items():
if (
info.scan_interval
and info.scan_interval != self._scan_intervals[template]
):
self._cancel_scan_interval_listener(template)
self._setup_scan_interval_listener(template, info.scan_interval)

had_all_listener = _TEMPLATE_ALL_LISTENER in self._listeners

if self._needs_all_listener:
Expand Down Expand Up @@ -680,6 +709,24 @@ def _update_listeners(self) -> None:
self._last_domains = domains
self._last_entities = entities

@callback
def _setup_scan_interval_listener(
self, template: Template, scan_interval: timedelta
) -> None:
self._scan_intervals[template] = scan_interval

# Set to None
if not scan_interval:
return

@callback
def _refresh_from_interval(now: datetime) -> None:
self._refresh(None, template)

self._scan_interval_listeners[template] = async_track_time_interval(
self.hass, _refresh_from_interval, scan_interval
)

@callback
def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
if domains:
Expand Down Expand Up @@ -712,64 +759,140 @@ def _setup_all_listener(self) -> None:
@callback
def async_remove(self) -> None:
"""Cancel the listener."""
self._cancel_listener(_TEMPLATE_ALL_LISTENER)
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
for key in list(self._listeners):
self._listeners.pop(key)()
for template in list(self._scan_interval_listeners):
self._scan_interval_listeners.pop(template)()
for track_template_ in self._track_templates:
self._cancel_rate_limit_timer(track_template_.template)

@callback
def async_refresh(self) -> None:
"""Force recalculate the template."""
self._refresh(None)

@callback
def _refresh(self, event: Optional[Event]) -> None:
entity_id = event and event.data.get(ATTR_ENTITY_ID)
lifecycle_event = event and (
event.data.get("new_state") is None or event.data.get("old_state") is None
def _handle_rate_limit(
self,
track_template_: TrackTemplate,
event: Event,
now: datetime,
) -> bool:
"""Check rate limits and call later if the rate limit is hit.

If there is already a call later scheduled for the template
we do not setup a second one.

Returns True if the rate limit has been hit or False on miss.
"""
template = track_template_.template
rate_limit = self._info[template].rate_limit or track_template_.rate_limit

if not rate_limit or template not in self._last_rendered:
return False

next_render_time = self._last_rendered[template] + rate_limit

if next_render_time <= now:
self._cancel_rate_limit_timer(template)
return False

_LOGGER.debug(
"Template rate_limit %s hit by event %s deferred by rate_limit %s to %s",
template.template,
event,
rate_limit,
next_render_time,
)
updates = []
info_changed = False

for track_template_ in self._track_templates:
template = track_template_.template
if template not in self._rate_limit_timers:
self._rate_limit_timers[template] = self.hass.loop.call_later(
(next_render_time - now).total_seconds() + MAX_TIME_TRACKING_ERROR,
self._refresh,
event,
)

return True

@callback
def _event_triggers_template(self, template: Template, event: Event) -> bool:
"""Determine if a template should be re-rendered from an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
return (
self._info[template].filter(entity_id)
or event.data.get("new_state") is None
or event.data.get("old_state") is None
and self._info[template].filter_lifecycle(entity_id)
)

@callback
def _render_template_if_ready(
self,
track_template_: TrackTemplate,
event: Optional[Event],
now: datetime,
template_filter: Optional[Template],
) -> Tuple[bool, Optional[TrackTemplateResult]]:
template = track_template_.template
if template_filter and template != template_filter:
return False, None

if event and template not in self._scan_interval_listeners:
if (
entity_id
and not self._last_info[template].filter(entity_id)
and (
not lifecycle_event
or not self._last_info[template].filter_lifecycle(entity_id)
)
template not in self._rate_limit_timers
and not self._event_triggers_template(template, event)
):
continue
return False, None

if self._handle_rate_limit(track_template_, event, now):
return False, None

_LOGGER.debug(
"Template update %s triggered by event: %s",
template.template,
event,
)
else:
self._cancel_rate_limit_timer(template)

self._info[template] = template.async_render_to_info(
track_template_.variables
)
info_changed = True
self._last_rendered[template] = now
self._info[template] = template.async_render_to_info(track_template_.variables)

try:
result: Union[str, TemplateError] = self._info[template].result()
except TemplateError as ex:
result = ex

try:
result: Union[str, TemplateError] = self._info[template].result()
except TemplateError as ex:
result = ex
last_result = self._last_result.get(template)

last_result = self._last_result.get(template)
# Check to see if the result has changed
if result == last_result:
return True, None

# Check to see if the result has changed
if result == last_result:
continue
if isinstance(result, TemplateError) and isinstance(last_result, TemplateError):
return True, None

if isinstance(result, TemplateError) and isinstance(
last_result, TemplateError
):
continue
return True, TrackTemplateResult(template, last_result, result)

updates.append(TrackTemplateResult(template, last_result, result))
@callback
def _refresh(
self, event: Optional[Event], template_filter: Optional[Template] = None
) -> None:
updates = []
info_changed = False
now = dt_util.utcnow()

for track_template_ in self._track_templates:
rendered, update = self._render_template_if_ready(
track_template_,
event,
now,
template_filter,
)
if rendered:
info_changed = True
if update:
updates.append(update)

if info_changed:
self._update_listeners()
Expand All @@ -778,7 +901,6 @@ def _refresh(self, event: Optional[Event]) -> None:
self._track_templates,
self.listeners,
)
self._last_info = self._info.copy()

if not updates:
return
Expand Down Expand Up @@ -1243,6 +1365,8 @@ def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set
domains = set()

for render_info in render_infos:
if render_info.scan_interval:
continue
if render_info.entities:
entities.update(render_info.entities)
if render_info.domains:
Expand Down
Loading