Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 14 additions & 88 deletions homeassistant/components/device_automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from functools import wraps
import logging
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Union, overload

import voluptuous as vol
import voluptuous_serialize

from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
Expand All @@ -28,11 +28,16 @@
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig

if TYPE_CHECKING:
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.helpers import condition
from .action import DeviceAutomationActionProtocol
from .condition import DeviceAutomationConditionProtocol
from .trigger import DeviceAutomationTriggerProtocol

DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]

# mypy: allow-untyped-calls, allow-untyped-defs

Expand Down Expand Up @@ -83,77 +88,6 @@ class DeviceAutomationType(Enum):
}


class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.

Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""

TRIGGER_SCHEMA: vol.Schema

async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError


class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.

Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""

CONDITION_SCHEMA: vol.Schema

async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError


class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.

Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""

ACTION_SCHEMA: vol.Schema

async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError


@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
Expand Down Expand Up @@ -193,14 +127,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True


DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]


@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
Expand Down Expand Up @@ -231,13 +157,13 @@ async def async_get_device_automation_platform( # noqa: D103
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
) -> "DeviceAutomationPlatformType":
...


async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
) -> "DeviceAutomationPlatformType":
"""Load device automation platform for integration.

Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
Expand Down
68 changes: 68 additions & 0 deletions homeassistant/components/device_automation/action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Device action validator."""
from __future__ import annotations

from typing import Any, Protocol, cast

import voluptuous as vol

from homeassistant.const import CONF_DOMAIN
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType

from . import DeviceAutomationType, async_get_device_automation_platform
from .exceptions import InvalidDeviceAutomationConfig


class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.

Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""

ACTION_SCHEMA: vol.Schema

async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError


async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
try:
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.ACTION
)
if hasattr(platform, "async_validate_action_config"):
return await platform.async_validate_action_config(hass, config)
return cast(ConfigType, platform.ACTION_SCHEMA(config))
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid action configuration") from err


async def async_call_action_from_config(
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
platform = await async_get_device_automation_platform(
hass,
config[CONF_DOMAIN],
DeviceAutomationType.ACTION,
)
await platform.async_call_action_from_config(hass, config, variables, context)
64 changes: 64 additions & 0 deletions homeassistant/components/device_automation/condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Validate device conditions."""
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, cast

import voluptuous as vol

from homeassistant.const import CONF_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType

from . import DeviceAutomationType, async_get_device_automation_platform
from .exceptions import InvalidDeviceAutomationConfig

if TYPE_CHECKING:
from homeassistant.helpers import condition


class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.

Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""

CONDITION_SCHEMA: vol.Schema

async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError


async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate device condition config."""
try:
config = cv.DEVICE_CONDITION_SCHEMA(config)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid condition configuration") from err


async def async_condition_from_config(
hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return platform.async_condition_from_config(hass, config)
38 changes: 31 additions & 7 deletions homeassistant/components/device_automation/trigger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Offer device oriented automation."""
from typing import cast
from typing import Protocol, cast

import voluptuous as vol

Expand All @@ -21,17 +21,41 @@
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)


class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.

Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""

TRIGGER_SCHEMA: vol.Schema

async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError

async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError


async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))

try:
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
return await platform.async_validate_trigger_config(hass, config)
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err
Expand Down
21 changes: 5 additions & 16 deletions homeassistant/helpers/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
from typing import Any, cast

from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
DeviceAutomationType,
async_get_device_automation_platform,
)
from homeassistant.components.device_automation import condition as device_condition
from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.const import (
ATTR_DEVICE_CLASS,
Expand All @@ -30,7 +27,6 @@
CONF_BELOW,
CONF_CONDITION,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_ID,
CONF_STATE,
Expand Down Expand Up @@ -872,10 +868,8 @@ async def async_device_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(platform.async_condition_from_config(hass, config))
checker = await device_condition.async_condition_from_config(hass, config)
return trace_condition_function(checker)


async def async_trigger_from_config(
Expand Down Expand Up @@ -931,15 +925,10 @@ async def async_validate_condition_config(
sub_cond = await async_validate_condition_config(hass, sub_cond)
conditions.append(sub_cond)
config["conditions"] = conditions
return config

if condition == "device":
config = cv.DEVICE_CONDITION_SCHEMA(config)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
return await device_condition.async_validate_condition_config(hass, config)

if condition in ("numeric_state", "state"):
validator = cast(
Expand Down
Loading