Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/components/zone/zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def async_active_zone(hass, latitude, longitude, radius=0):
return closest


def in_zone(zone, latitude, longitude, radius=0):
def in_zone(zone, latitude, longitude, radius=0) -> bool:
"""Test if given latitude, longitude is in given zone.
Async friendly.
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def __init__(self, entity_id: str, state: Any,
"State max length is 255 characters.").format(entity_id))

self.entity_id = entity_id.lower()
self.state = state
self.state = state # type: str
self.attributes = MappingProxyType(attributes or {})
self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated
Expand Down
141 changes: 92 additions & 49 deletions homeassistant/helpers/condition.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Offer reusable conditions."""
from datetime import timedelta
from datetime import datetime, timedelta
import functools as ft
import logging
import sys
from typing import Callable, Container, Optional, Union, cast

from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, TemplateVarsType

from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, State
from homeassistant.components import zone as zone_cmp
from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
Expand All @@ -29,25 +31,30 @@
# pylint: disable=invalid-name


def _threaded_factory(async_factory):
def _threaded_factory(async_factory:
Callable[[ConfigType, bool], Callable[..., bool]]) \
-> Callable[[ConfigType, bool], Callable[..., bool]]:
"""Create threaded versions of async factories."""
@ft.wraps(async_factory)
def factory(config, config_validation=True):
def factory(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Threaded factory."""
async_check = async_factory(config, config_validation)

def condition_if(hass, variables=None):
def condition_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate condition."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_check, hass, variables,
).result()
).result())

return condition_if

return factory


def async_from_config(config: ConfigType, config_validation: bool = True):
def async_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Turn a condition configuration into a method.

Should be run on the event loop.
Expand All @@ -64,20 +71,22 @@ def async_from_config(config: ConfigType, config_validation: bool = True):
raise HomeAssistantError('Invalid condition "{}" specified {}'.format(
config.get(CONF_CONDITION), config))

return factory(config, config_validation)
return cast(Callable[..., bool], factory(config, config_validation))


from_config = _threaded_factory(async_from_config)


def async_and_from_config(config: ConfigType, config_validation: bool = True):
def async_and_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'AND'."""
if config_validation:
config = cv.AND_CONDITION_SCHEMA(config)
checks = None

def if_and_condition(hass: HomeAssistant,
variables=None) -> bool:
variables: TemplateVarsType = None) -> bool:
"""Test and condition."""
nonlocal checks

Expand All @@ -101,14 +110,16 @@ def if_and_condition(hass: HomeAssistant,
and_from_config = _threaded_factory(async_and_from_config)


def async_or_from_config(config: ConfigType, config_validation: bool = True):
def async_or_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'OR'."""
if config_validation:
config = cv.OR_CONDITION_SCHEMA(config)
checks = None

def if_or_condition(hass: HomeAssistant,
variables=None) -> bool:
variables: TemplateVarsType = None) -> bool:
"""Test and condition."""
nonlocal checks

Expand All @@ -131,17 +142,22 @@ def if_or_condition(hass: HomeAssistant,
or_from_config = _threaded_factory(async_or_from_config)


def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
def numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
below: Optional[float] = None, above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_numeric_state, hass, entity, below, above,
value_template, variables,
).result()
).result())


def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
def async_numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
below: Optional[float] = None,
above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition."""
if isinstance(entity, str):
entity = hass.states.get(entity)
Expand All @@ -164,22 +180,24 @@ def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
return False

try:
value = float(value)
fvalue = float(value)
except ValueError:
_LOGGER.warning("Value cannot be processed as a number: %s "
"(Offending entity: %s)", entity, value)
return False

if below is not None and value >= below:
if below is not None and fvalue >= below:
return False

if above is not None and value <= above:
if above is not None and fvalue <= above:
return False

return True


def async_numeric_state_from_config(config, config_validation=True):
def async_numeric_state_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
Expand All @@ -188,7 +206,8 @@ def async_numeric_state_from_config(config, config_validation=True):
above = config.get(CONF_ABOVE)
value_template = config.get(CONF_VALUE_TEMPLATE)

def if_numeric_state(hass, variables=None):
def if_numeric_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test numeric state condition."""
if value_template is not None:
value_template.hass = hass
Expand All @@ -202,7 +221,8 @@ def if_numeric_state(hass, variables=None):
numeric_state_from_config = _threaded_factory(async_numeric_state_from_config)


def state(hass, entity, req_state, for_period=None):
def state(hass: HomeAssistant, entity: Union[None, str, State], req_state: str,
for_period: Optional[timedelta] = None) -> bool:
"""Test if state matches requirements.

Async friendly.
Expand All @@ -212,6 +232,7 @@ def state(hass, entity, req_state, for_period=None):

if entity is None:
return False
assert isinstance(entity, State)

is_state = entity.state == req_state

Expand All @@ -221,22 +242,26 @@ def state(hass, entity, req_state, for_period=None):
return dt_util.utcnow() - for_period > entity.last_changed


def state_from_config(config, config_validation=True):
def state_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
req_state = config.get(CONF_STATE)
req_state = cast(str, config.get(CONF_STATE))
for_period = config.get('for')

def if_state(hass, variables=None):
def if_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return state(hass, entity_id, req_state, for_period)

return if_state


def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
def sun(hass: HomeAssistant, before: Optional[str] = None,
after: Optional[str] = None, before_offset: Optional[timedelta] = None,
after_offset: Optional[timedelta] = None) -> bool:
"""Test if current time matches sun requirements."""
utcnow = dt_util.utcnow()
today = dt_util.as_local(utcnow).date()
Expand All @@ -254,22 +279,27 @@ def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
# There is no sunset today
return False

if before == SUN_EVENT_SUNRISE and utcnow > sunrise + before_offset:
if before == SUN_EVENT_SUNRISE and \
utcnow > cast(datetime, sunrise) + before_offset:
return False

if before == SUN_EVENT_SUNSET and utcnow > sunset + before_offset:
if before == SUN_EVENT_SUNSET and \
utcnow > cast(datetime, sunset) + before_offset:
return False

if after == SUN_EVENT_SUNRISE and utcnow < sunrise + after_offset:
if after == SUN_EVENT_SUNRISE and \
utcnow < cast(datetime, sunrise) + after_offset:
return False

if after == SUN_EVENT_SUNSET and utcnow < sunset + after_offset:
if after == SUN_EVENT_SUNSET and \
utcnow < cast(datetime, sunset) + after_offset:
return False

return True


def sun_from_config(config, config_validation=True):
def sun_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with sun based condition."""
if config_validation:
config = cv.SUN_CONDITION_SCHEMA(config)
Expand All @@ -278,21 +308,24 @@ def sun_from_config(config, config_validation=True):
before_offset = config.get('before_offset')
after_offset = config.get('after_offset')

def time_if(hass, variables=None):
def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)

return time_if


def template(hass, value_template, variables=None):
def template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches."""
return run_callback_threadsafe(
return cast(bool, run_callback_threadsafe(
hass.loop, async_template, hass, value_template, variables,
).result()
).result())


def async_template(hass, value_template, variables=None):
def async_template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches."""
try:
value = value_template.async_render(variables)
Expand All @@ -303,13 +336,16 @@ def async_template(hass, value_template, variables=None):
return value.lower() == 'true'


def async_template_from_config(config, config_validation=True):
def async_template_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition."""
if config_validation:
config = cv.TEMPLATE_CONDITION_SCHEMA(config)
value_template = config.get(CONF_VALUE_TEMPLATE)
value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))

def template_if(hass, variables=None):
def template_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate template based if-condition."""
value_template.hass = hass

Expand All @@ -321,7 +357,9 @@ def template_if(hass, variables=None):
template_from_config = _threaded_factory(async_template_from_config)


def time(before=None, after=None, weekday=None):
def time(before: Optional[dt_util.dt.time] = None,
after: Optional[dt_util.dt.time] = None,
weekday: Union[None, str, Container[str]] = None) -> bool:
"""Test if local time condition matches.

Handle the fact that time is continuous and we may be testing for
Expand Down Expand Up @@ -354,22 +392,25 @@ def time(before=None, after=None, weekday=None):
return True


def time_from_config(config, config_validation=True):
def time_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with time based condition."""
if config_validation:
config = cv.TIME_CONDITION_SCHEMA(config)
before = config.get(CONF_BEFORE)
after = config.get(CONF_AFTER)
weekday = config.get(CONF_WEEKDAY)

def time_if(hass, variables=None):
def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return time(before, after, weekday)

return time_if


def zone(hass, zone_ent, entity):
def zone(hass: HomeAssistant, zone_ent: Union[None, str, State],
entity: Union[None, str, State]) -> bool:
"""Test if zone-condition matches.

Async friendly.
Expand All @@ -396,14 +437,16 @@ def zone(hass, zone_ent, entity):
entity.attributes.get(ATTR_GPS_ACCURACY, 0))


def zone_from_config(config, config_validation=True):
def zone_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with zone based condition."""
if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
zone_entity_id = config.get(CONF_ZONE)

def if_in_zone(hass, variables=None):
def if_in_zone(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return zone(hass, zone_entity_id, entity_id)

Expand Down
Loading