Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions homeassistant/components/zha/button.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Support for ZHA button."""
from __future__ import annotations

import abc
import functools
import logging
from typing import Any

from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ENTITY_CATEGORY_DIAGNOSTIC, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .core import discovery
from .core.const import CHANNEL_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES
from .core.registries import ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity

MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BUTTON)
DEFAULT_DURATION = 5 # seconds

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the Zigbee Home Automation button from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON]

unsub = async_dispatcher_connect(
hass,
SIGNAL_ADD_ENTITIES,
functools.partial(
discovery.async_add_entities,
async_add_entities,
entities_to_create,
update_before_add=False,
),
)
config_entry.async_on_unload(unsub)


class ZHAButton(ZhaEntity, ButtonEntity):
"""Defines a ZHA button."""

_command_name: str = None

def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> None:
"""Init this button."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._channel: ChannelType = channels[0]

@abc.abstractmethod
def get_args(self) -> list[Any]:
"""Return the arguments to use in the command."""

async def async_press(self) -> None:
"""Send out a update command."""
command = getattr(self._channel, self._command_name)
arguments = self.get_args()
await command(*arguments)


@MULTI_MATCH(channel_names=CHANNEL_IDENTIFY)
class ZHAIdentifyButton(ZHAButton):
"""Defines a ZHA identify button."""

@classmethod
def create_entity(
cls,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> ZhaEntity | None:
"""Entity Factory.

Return entity if it is a supported configuration, otherwise return None
"""
platform_restrictions = ZHA_ENTITIES.single_device_matches[Platform.BUTTON]
device_restrictions = platform_restrictions[zha_device.ieee]
if CHANNEL_IDENTIFY in device_restrictions:
return None
device_restrictions.append(CHANNEL_IDENTIFY)
return cls(unique_id, zha_device, channels, **kwargs)

_attr_device_class: ButtonDeviceClass = ButtonDeviceClass.UPDATE
_attr_entity_category = ENTITY_CATEGORY_DIAGNOSTIC
_command_name = "identify"

def get_args(self) -> list[Any]:
"""Return the arguments to use in the command."""

return [DEFAULT_DURATION]
1 change: 1 addition & 0 deletions homeassistant/components/zha/core/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
PLATFORMS = (
Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR,
Platform.BUTTON,
Platform.CLIMATE,
Platform.COVER,
Platform.DEVICE_TRACKER,
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/zha/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .. import ( # noqa: F401 pylint: disable=unused-import,
alarm_control_panel,
binary_sensor,
button,
climate,
cover,
device_tracker,
Expand Down Expand Up @@ -66,6 +67,7 @@ def discover_entities(self, channel_pool: zha_typing.ChannelPoolType) -> None:
self.discover_by_device_type(channel_pool)
self.discover_multi_entities(channel_pool)
self.discover_by_cluster_id(channel_pool)
zha_regs.ZHA_ENTITIES.clean_up()

@callback
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:
Expand Down
10 changes: 10 additions & 0 deletions homeassistant/components/zha/core/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from zigpy import zcl
import zigpy.profiles.zha
import zigpy.profiles.zll
from zigpy.types.named import EUI64

from homeassistant.const import Platform

Expand Down Expand Up @@ -228,6 +229,9 @@ def __init__(self):
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._group_registry: dict[str, CALLABLE_T] = {}
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
] = collections.defaultdict(lambda: collections.defaultdict(list))

def get_entity(
self,
Expand Down Expand Up @@ -342,5 +346,11 @@ def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:

return decorator

def clean_up(self) -> None:
"""Clean up post discovery."""
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
] = collections.defaultdict(lambda: collections.defaultdict(list))


ZHA_ENTITIES = ZHAEntityRegistry()
89 changes: 89 additions & 0 deletions tests/components/zha/test_button.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Test ZHA button."""
from unittest.mock import patch

from freezegun import freeze_time
import pytest
from zigpy.const import SIG_EP_PROFILE
import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.security as security
import zigpy.zcl.foundation as zcl_f

from homeassistant.components.button import DOMAIN, ButtonDeviceClass
from homeassistant.components.button.const import SERVICE_PRESS
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_ENTITY_ID,
ENTITY_CATEGORY_DIAGNOSTIC,
STATE_UNKNOWN,
)
from homeassistant.helpers import entity_registry as er

from .common import find_entity_id
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE

from tests.common import mock_coro


@pytest.fixture
async def contact_sensor(hass, zigpy_device_mock, zha_device_joined_restored):
"""Contact sensor fixture."""

zigpy_device = zigpy_device_mock(
{
1: {
SIG_EP_INPUT: [
general.Basic.cluster_id,
general.Identify.cluster_id,
security.IasZone.cluster_id,
],
SIG_EP_OUTPUT: [],
SIG_EP_TYPE: zha.DeviceType.IAS_ZONE,
SIG_EP_PROFILE: zha.PROFILE_ID,
}
},
)

zha_device = await zha_device_joined_restored(zigpy_device)
return zha_device, zigpy_device.endpoints[1].identify


@freeze_time("2021-11-04 17:37:00", tz_offset=-1)
async def test_button(hass, contact_sensor):
"""Test zha button platform."""

entity_registry = er.async_get(hass)
zha_device, cluster = contact_sensor
assert cluster is not None
entity_id = await find_entity_id(DOMAIN, zha_device, hass)
assert entity_id is not None

state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE

entry = entity_registry.async_get(entity_id)
assert entry
assert entry.entity_category == ENTITY_CATEGORY_DIAGNOSTIC

with patch(
"zigpy.zcl.Cluster.request",
return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]),
):
await hass.services.async_call(
DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
await hass.async_block_till_done()
assert len(cluster.request.mock_calls) == 1
assert cluster.request.call_args[0][0] is False
assert cluster.request.call_args[0][1] == 0
assert cluster.request.call_args[0][3] == 5 # duration in seconds

state = hass.states.get(entity_id)
assert state
assert state.state == "2021-11-04T16:37:00+00:00"
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
27 changes: 16 additions & 11 deletions tests/components/zha/test_discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,25 @@ async def test_devices(
ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values()
}
assert event_channels == set(device[DEV_SIG_EVT_CHANNELS])

# we need to probe the class create entity factory so we need to reset this to get accurate results
zha_regs.ZHA_ENTITIES.clean_up()
# build a dict of entity_class -> (component, unique_id, channels) tuple
ha_ent_info = {}
created_entity_count = 0
for call in _dispatch.call_args_list:
_, component, entity_cls, unique_id, channels = call[0]
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
component,
unique_id,
channels,
)
# the factory can return None. We filter these out to get an accurate created entity count
response = entity_cls.create_entity(unique_id, zha_dev, channels)
if response:
created_entity_count += 1
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
component,
unique_id,
channels,
)

for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
component, unique_id = comp_id
Expand All @@ -156,7 +164,7 @@ async def test_devices(
assert unique_id.startswith(ha_unique_id)
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])

assert _dispatch.call_count == len(device[DEV_SIG_ENT_MAP])
assert created_entity_count == len(device[DEV_SIG_ENT_MAP])

entity_ids = hass_disable_services.states.async_entity_ids()
await hass_disable_services.async_block_till_done()
Expand Down Expand Up @@ -298,7 +306,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
assert device_info[DEV_SIG_EVT_CHANNELS] == sorted(
ch.id for pool in channels.pools for ch in pool.client_channels.values()
)
assert new_ent.call_count == len(list(device_info[DEV_SIG_ENT_MAP].values()))

# build a dict of entity_class -> (component, unique_id, channels) tuple
ha_ent_info = {}
Expand Down Expand Up @@ -326,8 +333,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
assert unique_id.startswith(ha_unique_id)
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])

assert new_ent.call_count == len(device_info[DEV_SIG_ENT_MAP])


def _ch_mock(cluster):
"""Return mock of a channel with a cluster."""
Expand Down
Loading