Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions homeassistant/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,7 @@

# The ID of the Home Assistant Cast App
CAST_APP_ID_HOMEASSISTANT = "B12CE3CA"

# The tracker error allow when converting
# loop time to human readable time
MAX_TIME_TRACKING_ERROR = 0.001
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __init__(
event_type: str,
data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None,
time_fired: Optional[datetime.datetime] = None,
context: Optional[Context] = None,
) -> None:
"""Initialize a new event."""
Expand Down
24 changes: 21 additions & 3 deletions homeassistant/helpers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED,
MATCH_ALL,
MAX_TIME_TRACKING_ERROR,
SUN_EVENT_SUNRISE,
SUN_EVENT_SUNSET,
)
Expand All @@ -40,15 +41,14 @@
)
from homeassistant.exceptions import TemplateError
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.ratelimit import KeyedRateLimit
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.helpers.template import RenderInfo, Template, result_as_boolean
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe

MAX_TIME_TRACKING_ERROR = 0.001

TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"

Expand Down Expand Up @@ -88,10 +88,12 @@ class TrackTemplate:

The template is template to calculate.
The variables are variables to pass to the template.
The rate_limit is a rate limit on how often the template is re-rendered.
"""

template: Template
variables: TemplateVarsType
rate_limit: Optional[timedelta] = None


@dataclass
Expand Down Expand Up @@ -724,6 +726,8 @@ def __init__(
self._track_templates = track_templates

self._last_result: Dict[Template, Union[str, TemplateError]] = {}

self._rate_limit = KeyedRateLimit(hass)
self._info: Dict[Template, RenderInfo] = {}
self._track_state_changes: Optional[_TrackStateChangeFiltered] = None

Expand Down Expand Up @@ -763,6 +767,7 @@ def async_remove(self) -> None:
"""Cancel the listener."""
assert self._track_state_changes
self._track_state_changes.async_remove()
self._rate_limit.async_remove()

@callback
def async_refresh(self) -> None:
Expand All @@ -784,11 +789,23 @@ def _event_triggers_template(self, template: Template, event: Event) -> bool:
def _refresh(self, event: Optional[Event]) -> None:
updates = []
info_changed = False
now = dt_util.utcnow()

for track_template_ in self._track_templates:
template = track_template_.template
if event:
if not self._event_triggers_template(template, event):
if not self._rate_limit.async_has_timer(
template
) and not self._event_triggers_template(template, event):
continue

if self._rate_limit.async_schedule_action(
template,
self._info[template].rate_limit or track_template_.rate_limit,
now,
self._refresh,
event,
):
continue

_LOGGER.debug(
Expand All @@ -797,6 +814,7 @@ def _refresh(self, event: Optional[Event]) -> None:
event,
)

self._rate_limit.async_triggered(template, now)
self._info[template] = template.async_render_to_info(
track_template_.variables
)
Expand Down
97 changes: 97 additions & 0 deletions homeassistant/helpers/ratelimit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Ratelimit helper."""
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, Callable, Dict, Hashable, Optional

from homeassistant.const import MAX_TIME_TRACKING_ERROR
from homeassistant.core import HomeAssistant, callback
import homeassistant.util.dt as dt_util

_LOGGER = logging.getLogger(__name__)


class KeyedRateLimit:
"""Class to track rate limits."""

def __init__(
self,
hass: HomeAssistant,
):
"""Initialize ratelimit tracker."""
self.hass = hass
self._last_triggered: Dict[Hashable, datetime] = {}
self._rate_limit_timers: Dict[Hashable, asyncio.TimerHandle] = {}

@callback
def async_has_timer(self, key: Hashable) -> bool:
"""Check if a rate limit timer is running."""
return key in self._rate_limit_timers

@callback
def async_triggered(self, key: Hashable, now: Optional[datetime] = None) -> None:
"""Call when the action we are tracking was triggered."""
self.async_cancel_timer(key)
self._last_triggered[key] = now or dt_util.utcnow()

@callback
def async_cancel_timer(self, key: Hashable) -> None:
"""Cancel a rate limit time that will call the action."""
if not self.async_has_timer(key):
return

self._rate_limit_timers.pop(key).cancel()

@callback
def async_remove(self) -> None:
"""Remove all timers."""
for timer in self._rate_limit_timers.values():
timer.cancel()
self._rate_limit_timers.clear()

@callback
def async_schedule_action(
self,
key: Hashable,
rate_limit: Optional[timedelta],
now: datetime,
action: Callable,
*args: Any,
) -> Optional[datetime]:
"""Check rate limits and schedule an action if we hit the limit.

If the rate limit is hit:
Schedules the action for when the rate limit expires
if there are no pending timers. The action must
be called in async.

Returns the time the rate limit will expire

If the rate limit is not hit:

Return None
"""
if rate_limit is None or key not in self._last_triggered:
return None

next_call_time = self._last_triggered[key] + rate_limit

if next_call_time <= now:
self.async_cancel_timer(key)
return None

_LOGGER.debug(
"Reached rate limit of %s for %s and deferred action until %s",
rate_limit,
key,
next_call_time,
)

if key not in self._rate_limit_timers:
self._rate_limit_timers[key] = self.hass.loop.call_later(
(next_call_time - now).total_seconds() + MAX_TIME_TRACKING_ERROR,
action,
*args,
)

return next_call_time
44 changes: 38 additions & 6 deletions homeassistant/helpers/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
"name",
}

DEFAULT_RATE_LIMIT = timedelta(seconds=1)


@bind_hass
def attach(hass: HomeAssistantType, obj: Any) -> None:
Expand Down Expand Up @@ -198,10 +200,11 @@ def __init__(self, template):
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
self.rate_limit = None

def __repr__(self) -> str:
"""Representation of RenderInfo."""
return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities}>"
return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities} rate_limit={self.rate_limit}>"

def _filter_domains_and_entities(self, entity_id: str) -> bool:
"""Template should re-render if the entity state changes when we match specific domains or entities."""
Expand All @@ -221,16 +224,24 @@ def result(self) -> str:

def _freeze_static(self) -> None:
self.is_static = True
self.entities = frozenset(self.entities)
self.domains = frozenset(self.domains)
self.domains_lifecycle = frozenset(self.domains_lifecycle)
self._freeze_sets()
self.all_states = False

def _freeze(self) -> None:
def _freeze_sets(self) -> None:
self.entities = frozenset(self.entities)
self.domains = frozenset(self.domains)
self.domains_lifecycle = frozenset(self.domains_lifecycle)

def _freeze(self) -> None:
self._freeze_sets()

if self.rate_limit is None and (
self.domains or self.domains_lifecycle or self.all_states or self.exception
):
# If the template accesses all states or an entire
# domain, and no rate limit is set, we use the default.
self.rate_limit = DEFAULT_RATE_LIMIT

if self.exception:
return

Expand Down Expand Up @@ -478,6 +489,26 @@ def __repr__(self) -> str:
return 'Template("' + self.template + '")'


class RateLimit:
"""Class to control update rate limits."""

def __init__(self, hass: HomeAssistantType):
"""Initialize rate limit."""
self._hass = hass

def __call__(self, *args: Any, **kwargs: Any) -> str:
"""Handle a call to the class."""
render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
render_info.rate_limit = timedelta(*args, **kwargs)

return ""

def __repr__(self) -> str:
"""Representation of a RateLimit."""
return "<template RateLimit>"


class AllStates:
"""Class to expose all HA states as attributes."""

Expand Down Expand Up @@ -1279,10 +1310,11 @@ def wrapper(*args, **kwargs):
self.globals["is_state_attr"] = hassfunction(is_state_attr)
self.globals["state_attr"] = hassfunction(state_attr)
self.globals["states"] = AllStates(hass)
self.globals["rate_limit"] = RateLimit(hass)

def is_safe_callable(self, obj):
"""Test if callback is safe."""
return isinstance(obj, AllStates) or super().is_safe_callable(obj)
return isinstance(obj, (AllStates, RateLimit)) or super().is_safe_callable(obj)

def is_safe_attribute(self, obj, attr, value):
"""Test if attribute is safe."""
Expand Down
1 change: 1 addition & 0 deletions tests/components/template/test_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ async def test_template_state_boolean(hass, calls):

async def test_template_position(hass, calls):
"""Test the position_template attribute."""
hass.states.async_set("cover.test", STATE_OPEN)
with assert_setup_component(1, "cover"):
assert await setup.async_setup_component(
hass,
Expand Down
16 changes: 13 additions & 3 deletions tests/components/template/test_sensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The test for the Template sensor platform."""
from asyncio import Event
from datetime import timedelta
from unittest.mock import patch

from homeassistant.bootstrap import async_from_config_dict
Expand All @@ -17,7 +18,11 @@
from homeassistant.setup import ATTR_COMPONENT, async_setup_component, setup_component
import homeassistant.util.dt as dt_util

from tests.common import assert_setup_component, get_test_home_assistant
from tests.common import (
assert_setup_component,
async_fire_time_changed,
get_test_home_assistant,
)


class TestTemplateSensor:
Expand Down Expand Up @@ -900,8 +905,13 @@ async def test_self_referencing_entity_picture_loop(hass, caplog):

assert len(hass.states.async_all()) == 1

await hass.async_block_till_done()
await hass.async_block_till_done()
next_time = dt_util.utcnow() + timedelta(seconds=1.2)
with patch(
"homeassistant.helpers.ratelimit.dt_util.utcnow", return_value=next_time
):
async_fire_time_changed(hass, next_time)
await hass.async_block_till_done()
await hass.async_block_till_done()

assert "Template loop detected" in caplog.text

Expand Down
Loading