Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __init__(
event_type: str,
data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None,
time_fired: Optional[datetime.datetime] = None,
context: Optional[Context] = None,
) -> None:
"""Initialize a new event."""
Expand Down
126 changes: 92 additions & 34 deletions homeassistant/helpers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ class TrackTemplate:

The template is template to calculate.
The variables are variables to pass to the template.
The rate_limit is a rate limit on how often the template is re-rendered.
"""

template: Template
variables: TemplateVarsType
rate_limit: Optional[timedelta] = None


@dataclass
Expand Down Expand Up @@ -560,7 +562,8 @@ def __init__(
self._listeners: Dict[str, Callable] = {}

self._last_result: Dict[Template, Union[str, TemplateError]] = {}
self._last_info: Dict[Template, RenderInfo] = {}
self._last_rendered: Dict[Template, datetime] = {}
self._rate_limit_timers: Dict[Template, asyncio.TimerHandle] = {}
self._info: Dict[Template, RenderInfo] = {}
self._last_domains: Set = set()
self._last_entities: Set = set()
Expand All @@ -581,7 +584,6 @@ def async_setup(self, raise_on_template_error: bool) -> None:
exc_info=self._info[template].exception,
)

self._last_info = self._info.copy()
self._create_listeners()
_LOGGER.debug(
"Template group %s listens for %s",
Expand All @@ -600,28 +602,23 @@ def listeners(self) -> Dict:

@property
def _needs_all_listener(self) -> bool:
for track_template_ in self._track_templates:
template = track_template_.template

for info in self._info.values():
# Tracking all states
if (
self._info[template].all_states
or self._info[template].all_states_lifecycle
):
if info.all_states or info.all_states_lifecycle:
return True

# Previous call had an exception
# so we do not know which states
# to track
if self._info[template].exception:
if info.exception:
return True

return False

@property
def _all_templates_are_static(self) -> bool:
for track_template_ in self._track_templates:
if not self._info[track_template_.template].is_static:
for info in self._info.values():
if not info.is_static:
return False

return True
Expand All @@ -648,6 +645,13 @@ def _cancel_listener(self, listener_name: str) -> None:

self._listeners.pop(listener_name)()

@callback
def _cancel_rate_limit_timer(self, template: Template) -> None:
if template not in self._rate_limit_timers:
return

self._rate_limit_timers.pop(template).cancel()

@callback
def _update_listeners(self) -> None:
had_all_listener = _TEMPLATE_ALL_LISTENER in self._listeners
Expand Down Expand Up @@ -712,42 +716,97 @@ def _setup_all_listener(self) -> None:
@callback
def async_remove(self) -> None:
"""Cancel the listener."""
self._cancel_listener(_TEMPLATE_ALL_LISTENER)
self._cancel_listener(_TEMPLATE_DOMAINS_LISTENER)
self._cancel_listener(_TEMPLATE_ENTITIES_LISTENER)
for key in list(self._listeners):
self._listeners.pop(key)()
for track_template_ in self._track_templates:
self._cancel_rate_limit_timer(track_template_.template)

@callback
def async_refresh(self) -> None:
"""Force recalculate the template."""
self._refresh(None)

@callback
def _refresh(self, event: Optional[Event]) -> None:
entity_id = event and event.data.get(ATTR_ENTITY_ID)
lifecycle_event = event and (
event.data.get("new_state") is None or event.data.get("old_state") is None
def _handle_rate_limit(
self,
track_template_: TrackTemplate,
event: Event,
now: datetime,
) -> bool:
"""Check rate limits and call later if the rate limit is hit.

If there is already a call later scheduled for the template
we do not setup a second one.

Returns True if the rate limit has been hit or False on miss.
"""
template = track_template_.template
rate_limit = self._info[template].rate_limit or track_template_.rate_limit

if not rate_limit or template not in self._last_rendered:
return False

next_render_time = self._last_rendered[template] + rate_limit

if next_render_time <= now:
self._cancel_rate_limit_timer(template)
return False

_LOGGER.debug(
"Template rate_limit %s hit by event %s deferred by rate_limit %s to %s",
template.template,
event,
rate_limit,
next_render_time,
)

if template not in self._rate_limit_timers:
self._rate_limit_timers[template] = self.hass.loop.call_later(
(next_render_time - now).total_seconds() + MAX_TIME_TRACKING_ERROR,
self._refresh,
event,
)

return True

@callback
def _event_triggers_template(self, template: Template, event: Event) -> bool:
"""Determine if a template should be re-rendered from an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
return (
self._info[template].filter(entity_id)
or event.data.get("new_state") is None
or event.data.get("old_state") is None
and self._info[template].filter_lifecycle(entity_id)
)

@callback
def _refresh(self, event: Optional[Event]) -> None:
updates = []
info_changed = False
now = dt_util.utcnow()

for track_template_ in self._track_templates:
template = track_template_.template
if (
entity_id
and not self._last_info[template].filter(entity_id)
and (
not lifecycle_event
or not self._last_info[template].filter_lifecycle(entity_id)
if event:
if (
template not in self._rate_limit_timers
and not self._event_triggers_template(template, event)
):
continue

if self._handle_rate_limit(track_template_, event, now):
continue

_LOGGER.debug(
"Template update %s triggered by event: %s",
template.template,
event,
)
):
continue

_LOGGER.debug(
"Template update %s triggered by event: %s",
template.template,
event,
)
else:
self._cancel_rate_limit_timer(template)

self._last_rendered[template] = now
self._info[template] = template.async_render_to_info(
track_template_.variables
)
Expand Down Expand Up @@ -778,7 +837,6 @@ def _refresh(self, event: Optional[Event]) -> None:
self._track_templates,
self.listeners,
)
self._last_info = self._info.copy()

if not updates:
return
Expand Down
28 changes: 26 additions & 2 deletions homeassistant/helpers/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,11 @@ def __init__(self, template):
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
self.rate_limit = None

def __repr__(self) -> str:
"""Representation of RenderInfo."""
return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities}>"
return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities} rate_limit={self.rate_limit}>"

def _filter_domains_and_entities(self, entity_id: str) -> bool:
"""Template should re-render if the entity state changes when we match specific domains or entities."""
Expand Down Expand Up @@ -467,6 +468,28 @@ def __repr__(self) -> str:
return 'Template("' + self.template + '")'


class RateLimit:
"""Class to control update rate limits."""

def __init__(self, hass: HomeAssistantType):
"""Initialize rate limit."""
self._hass = hass

def __call__(self, *args: Any, **kwargs: Any) -> Optional[timedelta]:
"""Handle a call to the class."""
delta = timedelta(*args, **kwargs)

render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
render_info.rate_limit = delta

return delta

def __repr__(self) -> str:
"""Representation of a RateLimit."""
return "<template RateLimit>"


class AllStates:
"""Class to expose all HA states as attributes."""

Expand Down Expand Up @@ -1201,10 +1224,11 @@ def wrapper(*args, **kwargs):
self.globals["is_state_attr"] = hassfunction(is_state_attr)
self.globals["state_attr"] = hassfunction(state_attr)
self.globals["states"] = AllStates(hass)
self.globals["rate_limit"] = RateLimit(hass)

def is_safe_callable(self, obj):
"""Test if callback is safe."""
return isinstance(obj, AllStates) or super().is_safe_callable(obj)
return isinstance(obj, (AllStates, RateLimit)) or super().is_safe_callable(obj)

def is_safe_attribute(self, obj, attr, value):
"""Test if attribute is safe."""
Expand Down
Loading