Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
CONF_CONDITION,
CONF_INITIAL_STATE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DEFAULT_INITIAL_STATE,
DOMAIN,
LOGGER,
Expand Down Expand Up @@ -221,6 +222,7 @@ def __init__(
action_script,
initial_state,
variables,
trigger_variables,
):
"""Initialize an automation entity."""
self._id = automation_id
Expand All @@ -236,6 +238,7 @@ def __init__(
self._referenced_devices: Optional[Set[str]] = None
self._logger = LOGGER
self._variables: ScriptVariables = variables
self._trigger_variables: ScriptVariables = trigger_variables

@property
def name(self):
Expand Down Expand Up @@ -465,6 +468,16 @@ async def _async_attach_triggers(
def log_cb(level, msg, **kwargs):
self._logger.log(level, "%s %s", msg, self._name, **kwargs)

variables = None
if self._trigger_variables:
try:
variables = self._trigger_variables.async_render(
cast(HomeAssistant, self.hass), None, limited=True
)
except template.TemplateError as err:
self._logger.error("Error rendering trigger variables: %s", err)
return None

return await async_initialize_triggers(
cast(HomeAssistant, self.hass),
self._trigger_config,
Expand All @@ -473,6 +486,7 @@ def log_cb(level, msg, **kwargs):
self._name,
log_cb,
home_assistant_start,
variables,
)

@property
Expand Down Expand Up @@ -550,14 +564,27 @@ async def _async_process_config(
else:
cond_func = None

# Add trigger variables to variables
variables = None
if CONF_TRIGGER_VARIABLES in config_block:
variables = ScriptVariables(
dict(config_block[CONF_TRIGGER_VARIABLES].as_dict())
)
if CONF_VARIABLES in config_block:
if variables:
variables.variables.update(config_block[CONF_VARIABLES].as_dict())
else:
variables = config_block[CONF_VARIABLES]

entity = AutomationEntity(
automation_id,
name,
config_block[CONF_TRIGGER],
cond_func,
action_script,
initial_state,
config_block.get(CONF_VARIABLES),
variables,
config_block.get(CONF_TRIGGER_VARIABLES),
)

entities.append(entity)
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/automation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CONF_HIDE_ENTITY,
CONF_INITIAL_STATE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DOMAIN,
)
from .helpers import async_get_blueprints
Expand All @@ -44,6 +45,7 @@
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Optional(CONF_TRIGGER_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
},
script.SCRIPT_MODE_SINGLE,
Expand Down
1 change: 1 addition & 0 deletions homeassistant/components/automation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CONF_CONDITION = "condition"
CONF_ACTION = "action"
CONF_TRIGGER = "trigger"
CONF_TRIGGER_VARIABLES = "trigger_variables"
DOMAIN = "automation"

CONF_DESCRIPTION = "description"
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/mqtt/device_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ class TriggerInstance:
async def async_attach_trigger(self):
"""Attach MQTT trigger."""
mqtt_config = {
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
mqtt_trigger.CONF_TOPIC: self.trigger.topic,
mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING,
mqtt_trigger.CONF_QOS: self.trigger.qos,
}
if self.trigger.payload:
mqtt_config[CONF_PAYLOAD] = self.trigger.payload
mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config)

if self.remove:
self.remove()
Expand Down
25 changes: 22 additions & 3 deletions homeassistant/components/mqtt/trigger.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Offer MQTT listening automation rules."""
import json
import logging

import voluptuous as vol

from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM
from homeassistant.core import HassJob, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import config_validation as cv, template

from .. import mqtt

Expand All @@ -20,15 +21,17 @@
TRIGGER_SCHEMA = vol.Schema(
{
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
vol.Optional(CONF_PAYLOAD): cv.string,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic_template,
vol.Optional(CONF_PAYLOAD): cv.template,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
vol.Coerce(int), vol.In([0, 1, 2])
),
}
)

_LOGGER = logging.getLogger(__name__)


async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for state changes based on configuration."""
Expand All @@ -37,6 +40,18 @@ async def async_attach_trigger(hass, config, action, automation_info):
encoding = config[CONF_ENCODING] or None
qos = config[CONF_QOS]
job = HassJob(action)
variables = None
if automation_info:
variables = automation_info.get("variables")

template.attach(hass, payload)
if payload:
payload = payload.async_render(variables, limited=True)
Comment thread
balloob marked this conversation as resolved.

template.attach(hass, topic)
if isinstance(topic, template.Template):
topic = topic.async_render(variables, limited=True)
topic = mqtt.util.valid_subscribe_topic(topic)

@callback
def mqtt_automation_listener(mqttmsg):
Expand All @@ -57,6 +72,10 @@ def mqtt_automation_listener(mqttmsg):

hass.async_run_hass_job(job, {"trigger": data})

_LOGGER.debug(
"Attaching MQTT trigger for topic: '%s', payload: '%s'", topic, payload
)

remove = await mqtt.async_subscribe(
hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos
)
Expand Down
12 changes: 11 additions & 1 deletion homeassistant/components/mqtt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import voluptuous as vol

from homeassistant.const import CONF_PAYLOAD
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import config_validation as cv, template

from .const import (
ATTR_PAYLOAD,
Expand Down Expand Up @@ -61,6 +61,16 @@ def valid_subscribe_topic(value: Any) -> str:
return value


def valid_subscribe_topic_template(value: Any) -> template.Template:
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
tpl = template.Template(value)

if tpl.is_static:
valid_subscribe_topic(value)

return tpl


def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string")
if not template_helper.is_template_string(str(value)):
raise vol.Invalid("template value does not contain a dynmamic template")
raise vol.Invalid("template value does not contain a dynamic template")

template_value = template_helper.Template(str(value)) # type: ignore
try:
Expand Down
5 changes: 4 additions & 1 deletion homeassistant/helpers/script_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def async_render(
run_variables: Optional[Mapping[str, Any]],
*,
render_as_defaults: bool = True,
limited: bool = False,
) -> Dict[str, Any]:
"""Render script variables.

Expand Down Expand Up @@ -55,7 +56,9 @@ def async_render(
if render_as_defaults and key in rendered_variables:
continue

rendered_variables[key] = template.render_complex(value, rendered_variables)
rendered_variables[key] = template.render_complex(
value, rendered_variables, limited
)

return rendered_variables

Expand Down
54 changes: 47 additions & 7 deletions homeassistant/helpers/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
obj.hass = hass


def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
def render_complex(
value: Any, variables: TemplateVarsType = None, limited: bool = False
) -> Any:
"""Recursive template creator helper function."""
if isinstance(value, list):
return [render_complex(item, variables) for item in value]
Expand All @@ -94,7 +96,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
for key, item in value.items()
}
if isinstance(value, Template):
return value.async_render(variables)
return value.async_render(variables, limited=limited)

return value

Expand Down Expand Up @@ -279,6 +281,7 @@ class Template:
"is_static",
"_compiled_code",
"_compiled",
"_limited",
)

def __init__(self, template, hass=None):
Expand All @@ -291,10 +294,11 @@ def __init__(self, template, hass=None):
self._compiled: Optional[Template] = None
self.hass = hass
self.is_static = not is_template_string(template)
self._limited = None

@property
def _env(self) -> "TemplateEnvironment":
if self.hass is None:
if self.hass is None or self._limited:
return _NO_HASS_ENV
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
if ret is None:
Expand All @@ -315,36 +319,43 @@ def render(
self,
variables: TemplateVarsType = None,
parse_result: bool = True,
limited: bool = False,
**kwargs: Any,
) -> Any:
"""Render given template."""
"""Render given template.

If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
"""
if self.is_static:
if self.hass.config.legacy_templates or not parse_result:
return self.template
return self._parse_result(self.template)

return run_callback_threadsafe(
self.hass.loop,
partial(self.async_render, variables, parse_result, **kwargs),
partial(self.async_render, variables, parse_result, limited, **kwargs),
).result()

@callback
def async_render(
self,
variables: TemplateVarsType = None,
parse_result: bool = True,
limited: bool = False,
Comment thread
balloob marked this conversation as resolved.
**kwargs: Any,
) -> Any:
"""Render given template.

This method must be run in the event loop.

If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's nice to break long strings around 88 characters.

"""
if self.is_static:
if self.hass.config.legacy_templates or not parse_result:
return self.template
return self._parse_result(self.template)

compiled = self._compiled or self._ensure_compiled()
compiled = self._compiled or self._ensure_compiled(limited)

if variables is not None:
kwargs.update(variables)
Expand Down Expand Up @@ -519,12 +530,16 @@ def async_render_with_possible_json_value(
)
return value if error_value is _SENTINEL else error_value

def _ensure_compiled(self) -> "Template":
def _ensure_compiled(self, limited: bool = False) -> "Template":
"""Bind a template to a specific hass instance."""
self.ensure_valid()

assert self.hass is not None, "hass variable not set on template"
assert (
self._limited is None or self._limited == limited
), "can't change between limited and non limited template"

self._limited = limited
env = self._env

self._compiled = cast(
Expand Down Expand Up @@ -1352,6 +1367,31 @@ def __init__(self, hass):
self.globals["strptime"] = strptime
self.globals["urlencode"] = urlencode
if hass is None:

def unsupported(name):
def warn_unsupported(*args, **kwargs):
raise TemplateError(
f"Use of '{name}' is not supported in limited templates"
)

return warn_unsupported

hass_globals = [
"closest",
"distance",
"expand",
"is_state",
"is_state_attr",
"state_attr",
"states",
"utcnow",
"now",
]
hass_filters = ["closest", "expand"]
for glob in hass_globals:
self.globals[glob] = unsupported(glob)
for filt in hass_filters:
self.filters[filt] = unsupported(filt)
return

# We mark these as a context functions to ensure they get
Expand Down
5 changes: 4 additions & 1 deletion homeassistant/helpers/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.loader import IntegrationNotFound, async_get_integration

Expand Down Expand Up @@ -79,7 +80,9 @@ async def async_initialize_triggers(
removes = []

for result in attach_results:
if isinstance(result, Exception):
if isinstance(result, HomeAssistantError):
log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for")
elif isinstance(result, Exception):
log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
elif result is None:
log_cb(
Expand Down
Loading