Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions homeassistant/components/device_automation/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from homeassistant.const import CONF_DOMAIN
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.trigger import (
TriggerActionType,
TriggerInfo,
TriggerProtocol,
)
from homeassistant.helpers.typing import ConfigType

from . import (
Expand All @@ -20,28 +24,13 @@
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)


class DeviceAutomationTriggerProtocol(Protocol):
class DeviceAutomationTriggerProtocol(TriggerProtocol, Protocol):
Comment thread
epenet marked this conversation as resolved.
"""Define the format of device_trigger modules.

Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config
from TriggerProtocol.
"""

TRIGGER_SCHEMA: vol.Schema

async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""

async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""

async def async_get_trigger_capabilities(
self, hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]:
Expand Down
32 changes: 25 additions & 7 deletions homeassistant/helpers/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass, field
import functools
import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
from typing import Any, Protocol, TypedDict, cast

import voluptuous as vol

Expand All @@ -31,11 +31,6 @@

from .typing import ConfigType, TemplateVarsType

if TYPE_CHECKING:
from homeassistant.components.device_automation.trigger import (
DeviceAutomationTriggerProtocol,
)

_PLATFORM_ALIASES = {
"device_automation": ("device",),
"homeassistant": ("event", "numeric_state", "state", "time_pattern", "time"),
Expand All @@ -44,6 +39,29 @@
DATA_PLUGGABLE_ACTIONS = "pluggable_actions"


class TriggerProtocol(Protocol):
"""Define the format of trigger modules.

Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""

TRIGGER_SCHEMA: vol.Schema

async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""

async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""


class TriggerActionType(Protocol):
"""Protocol type for trigger action callback."""

Expand Down Expand Up @@ -191,7 +209,7 @@ async def async_run(

async def _async_get_trigger_platform(
hass: HomeAssistant, config: ConfigType
) -> DeviceAutomationTriggerProtocol:
) -> TriggerProtocol:
platform_and_sub_type = config[CONF_PLATFORM].split(".")
platform = platform_and_sub_type[0]
for alias, triggers in _PLATFORM_ALIASES.items():
Expand Down