Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,22 +239,32 @@ def handle_ping(hass, connection, msg):
connection.send_message(pong_message(msg["id"]))


@callback
@decorators.websocket_command(
{
vol.Required("type"): "render_template",
vol.Required("template"): str,
vol.Optional("entity_ids"): cv.entity_ids,
vol.Optional("variables"): dict,
vol.Optional("timeout"): vol.Coerce(float),
}
)
def handle_render_template(hass, connection, msg):
@decorators.async_response
async def handle_render_template(hass, connection, msg):
"""Handle render_template command."""
template_str = msg["template"]
template = Template(template_str, hass)
variables = msg.get("variables")
timeout = msg.get("timeout")
info = None

if timeout and await template.async_render_will_timeout(timeout):
connection.send_error(
msg["id"],
const.ERR_TEMPLATE_ERROR,
f"Exceeded maximum execution time of {timeout}s",
)
return

@callback
def _template_listener(event, updates):
nonlocal info
Expand Down
50 changes: 50 additions & 0 deletions homeassistant/helpers/template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Template helper methods for rendering strings with Home Assistant data."""
import asyncio
import base64
import collections.abc
from datetime import datetime, timedelta
Expand Down Expand Up @@ -36,6 +37,7 @@
from homeassistant.loader import bind_hass
from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException

# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
Expand Down Expand Up @@ -309,6 +311,54 @@ def async_render(self, variables: TemplateVarsType = None, **kwargs: Any) -> str
except jinja2.TemplateError as err:
raise TemplateError(err) from err

async def async_render_will_timeout(
self, timeout: float, variables: TemplateVarsType = None, **kwargs: Any
) -> bool:
"""Check to see if rendering a template will timeout during render.

This is intended to check for expensive templates
that will make the system unstable. The template
is rendered in the executor to ensure it does not
tie up the event loop.

This function is not a security control and is only
intended to be used as a safety check when testing
templates.

This method must be run in the event loop.
"""
assert self.hass

if self.is_static:
return False

compiled = self._compiled or self._ensure_compiled()

if variables is not None:
kwargs.update(variables)

finish_event = asyncio.Event()

def _render_template():
try:
compiled.render(kwargs)
except TimeoutError:
pass
finally:
run_callback_threadsafe(self.hass.loop, finish_event.set)

try:
template_render_thread = ThreadWithException(target=_render_template)
template_render_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=timeout)
except asyncio.TimeoutError:
template_render_thread.raise_exc(TimeoutError)
return True
finally:
template_render_thread.join()

return False

@callback
def async_render_to_info(
self, variables: TemplateVarsType = None, **kwargs: Any
Expand Down
33 changes: 33 additions & 0 deletions homeassistant/util/thread.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Threading util helpers."""
import ctypes
import inspect
import sys
import threading
from typing import Any
Expand All @@ -24,3 +26,34 @@ def run(*args: Any, **kwargs: Any) -> None:
sys.excepthook(*sys.exc_info())

threading.Thread.run = run # type: ignore


def _async_raise(tid: int, exctype: Any) -> None:
"""Raise an exception in the threads with id tid."""
if not inspect.isclass(exctype):
raise TypeError("Only types can be raised (not instances)")

c_tid = ctypes.c_long(tid)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype))

if res == 1:
return

# "if it returns a number greater than one, you're in trouble,
# and you should call it again with exc=NULL to revert the effect"
ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, None)
raise SystemError("PyThreadState_SetAsyncExc failed")


class ThreadWithException(threading.Thread):
"""A thread class that supports raising exception in the thread from another thread.

Based on
https://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread/49877671

"""

def raise_exc(self, exctype: Any) -> None:
"""Raise the given exception type in the context of this thread."""
assert self.ident
_async_raise(self.ident, exctype)
47 changes: 34 additions & 13 deletions tests/components/websocket_api/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,7 @@ async def test_subscribe_unsubscribe_events_state_changed(
assert msg["event"]["data"]["entity_id"] == "light.permitted"


async def test_render_template_renders_template(
hass, websocket_client, hass_admin_user
):
async def test_render_template_renders_template(hass, websocket_client):
"""Test simple template is rendered and updated."""
hass.states.async_set("light.test", "on")

Expand Down Expand Up @@ -437,7 +435,7 @@ async def test_render_template_renders_template(


async def test_render_template_manual_entity_ids_no_longer_needed(
hass, websocket_client, hass_admin_user
hass, websocket_client
):
"""Test that updates to specified entity ids cause a template rerender."""
hass.states.async_set("light.test", "on")
Expand Down Expand Up @@ -475,9 +473,7 @@ async def test_render_template_manual_entity_ids_no_longer_needed(
}


async def test_render_template_with_error(
hass, websocket_client, hass_admin_user, caplog
):
async def test_render_template_with_error(hass, websocket_client, caplog):
"""Test a template with an error."""
await websocket_client.send_json(
{"id": 5, "type": "render_template", "template": "{{ my_unknown_var() + 1 }}"}
Expand All @@ -492,9 +488,7 @@ async def test_render_template_with_error(
assert "TemplateError" not in caplog.text


async def test_render_template_with_delayed_error(
hass, websocket_client, hass_admin_user, caplog
):
async def test_render_template_with_delayed_error(hass, websocket_client, caplog):
"""Test a template with an error that only happens after a state change."""
hass.states.async_set("sensor.test", "on")
await hass.async_block_till_done()
Expand Down Expand Up @@ -539,9 +533,36 @@ async def test_render_template_with_delayed_error(
assert "TemplateError" not in caplog.text


async def test_render_template_returns_with_match_all(
hass, websocket_client, hass_admin_user
):
async def test_render_template_with_timeout(hass, websocket_client, caplog):
"""Test a template that will timeout."""

slow_template_str = """
{% for var in range(1000) -%}
{% for var in range(1000) -%}
{{ var }}
{%- endfor %}
{%- endfor %}
"""

await websocket_client.send_json(
{
"id": 5,
"type": "render_template",
"timeout": 0.000001,
"template": slow_template_str,
}
)

msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_TEMPLATE_ERROR

assert "TemplateError" not in caplog.text


async def test_render_template_returns_with_match_all(hass, websocket_client):
"""Test that a template that would match with all entities still return success."""
await websocket_client.send_json(
{"id": 5, "type": "render_template", "template": "State is: {{ 42 }}"}
Expand Down
28 changes: 28 additions & 0 deletions tests/helpers/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,3 +2455,31 @@ async def test_lifecycle(hass):
assert info.filter("sensor.sensor1") is False
assert info.filter_lifecycle("sensor.new") is True
assert info.filter_lifecycle("sensor.removed") is True


async def test_template_timeout(hass):
"""Test to see if a template will timeout."""
for i in range(2):
hass.states.async_set(f"sensor.sensor{i}", "on")

tmp = template.Template("{{ states | count }}", hass)
assert await tmp.async_render_will_timeout(3) is False

tmp2 = template.Template("{{ error_invalid + 1 }}", hass)
assert await tmp2.async_render_will_timeout(3) is False

tmp3 = template.Template("static", hass)
assert await tmp3.async_render_will_timeout(3) is False

tmp4 = template.Template("{{ var1 }}", hass)
assert await tmp4.async_render_will_timeout(3, {"var1": "ok"}) is False

slow_template_str = """
{% for var in range(1000) -%}
{% for var in range(1000) -%}
{{ var }}
{%- endfor %}
{%- endfor %}
"""
tmp5 = template.Template(slow_template_str, hass)
assert await tmp5.async_render_will_timeout(0.000001) is True
55 changes: 55 additions & 0 deletions tests/util/test_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Test Home Assistant thread utils."""

import asyncio

import pytest

from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException


async def test_thread_with_exception_invalid(hass):
"""Test throwing an invalid thread exception."""

finish_event = asyncio.Event()

def _do_nothing(*_):
run_callback_threadsafe(hass.loop, finish_event.set)

test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)

with pytest.raises(TypeError):
test_thread.raise_exc(_EmptyClass())
test_thread.join()


async def test_thread_not_started(hass):
"""Test throwing when the thread is not started."""

test_thread = ThreadWithException(target=lambda *_: None)

with pytest.raises(AssertionError):
test_thread.raise_exc(TimeoutError)


async def test_thread_fails_raise(hass):
"""Test throwing after already ended."""

finish_event = asyncio.Event()

def _do_nothing(*_):
run_callback_threadsafe(hass.loop, finish_event.set)

test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
test_thread.join()

with pytest.raises(SystemError):
test_thread.raise_exc(ValueError)


class _EmptyClass:
"""An empty class."""