Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/components/zha/core/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .decorators import CALLABLE_T, DictRegistry, SetRegistry
from .typing import ChannelType

GROUP_ENTITY_DOMAINS = [LIGHT]
GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH]

SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02
SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000
Expand Down
106 changes: 86 additions & 20 deletions homeassistant/components/zha/switch.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Switches on Zigbee Home Automation networks."""
import functools
import logging
from typing import Any, List, Optional

from zigpy.zcl.clusters.general import OnOff
from zigpy.zcl.foundation import Status

from homeassistant.components.switch import DOMAIN, SwitchDevice
from homeassistant.const import STATE_ON
from homeassistant.core import callback
from homeassistant.const import STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_state_change

from .core import discovery
from .core.const import (
Expand All @@ -16,12 +19,14 @@
DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
)
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity
from .entity import BaseZhaEntity, ZhaEntity

_LOGGER = logging.getLogger(__name__)
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, DOMAIN)


async def async_setup_entry(hass, config_entry, async_add_entities):
Expand All @@ -38,14 +43,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)


@STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Switch(ZhaEntity, SwitchDevice):
"""ZHA switch."""
class BaseSwitch(BaseZhaEntity, SwitchDevice):
"""Common base class for zha switches."""

def __init__(self, unique_id, zha_device, channels, **kwargs):
def __init__(self, *args, **kwargs):
"""Initialize the ZHA switch."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)
self._on_off_channel = None
self._state = None
super().__init__(*args, **kwargs)

@property
def is_on(self) -> bool:
Expand All @@ -54,49 +59,110 @@ def is_on(self) -> bool:
return False
return self._state

async def async_turn_on(self, **kwargs):
async def async_turn_on(self, **kwargs) -> None:
"""Turn the entity on."""
result = await self._on_off_channel.on()
if not isinstance(result, list) or result[1] is not Status.SUCCESS:
return
self._state = True
self.async_write_ha_state()

async def async_turn_off(self, **kwargs):
async def async_turn_off(self, **kwargs) -> None:
"""Turn the entity off."""
result = await self._on_off_channel.off()
if not isinstance(result, list) or result[1] is not Status.SUCCESS:
return
self._state = False
self.async_write_ha_state()


@STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Switch(ZhaEntity, BaseSwitch):
"""ZHA switch."""

def __init__(self, unique_id, zha_device, channels, **kwargs):
"""Initialize the ZHA switch."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)

@callback
def async_set_state(self, attr_id, attr_name, value):
def async_set_state(self, attr_id: int, attr_name: str, value: Any):
"""Handle state update from channel."""
self._state = bool(value)
self.async_write_ha_state()

@property
def device_state_attributes(self):
"""Return state attributes."""
return self.state_attributes

async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_accept_signal(
self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)

@callback
def async_restore_last_state(self, last_state):
def async_restore_last_state(self, last_state) -> None:
"""Restore previous state."""
self._state = last_state.state == STATE_ON

async def async_update(self):
async def async_update(self) -> None:
"""Attempt to retrieve on off state from the switch."""
await super().async_update()
if self._on_off_channel:
state = await self._on_off_channel.get_attribute_value("on_off")
if state is not None:
self._state = state


@GROUP_MATCH()
class SwitchGroup(BaseSwitch):
"""Representation of a switch group."""

def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a switch group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
self._available: bool = False
self._entity_ids: List[str] = entity_ids
group = self.zha_device.gateway.get_group(self._group_id)
self._on_off_channel = group.endpoint[OnOff.cluster_id]
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None

async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)

@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)

self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()

async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None

async def async_update(self) -> None:
"""Query all members and determine the light group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]
states: List[State] = list(filter(None, all_states))
on_states = [state for state in states if state.state == STATE_ON]

self._state = len(on_states) > 0
self._available = any(state.state != STATE_UNAVAILABLE for state in states)
156 changes: 156 additions & 0 deletions tests/components/zha/test_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import call, patch

import pytest
import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.general as general
import zigpy.zcl.foundation as zcl_f

Expand All @@ -10,15 +11,19 @@

from .common import (
async_enable_traffic,
async_find_group_entity_id,
async_test_rejoin,
find_entity_id,
get_zha_gateway,
send_attributes_report,
)

from tests.common import mock_coro

ON = 1
OFF = 0
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"


@pytest.fixture
Expand All @@ -34,6 +39,64 @@ def zigpy_device(zigpy_device_mock):
return zigpy_device_mock(endpoints)


@pytest.fixture
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""

zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee="00:15:8d:00:02:32:4f:32",
nwk=0x0000,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device


@pytest.fixture
async def device_switch_1(hass, zigpy_device_mock, zha_device_joined):
"""Test zha switch platform."""

zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [general.OnOff.cluster_id],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device


@pytest.fixture
async def device_switch_2(hass, zigpy_device_mock, zha_device_joined):
"""Test zha switch platform."""

zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [general.OnOff.cluster_id],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE2,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device


async def test_switch(hass, zha_device_joined_restored, zigpy_device):
"""Test zha switch platform."""

Expand Down Expand Up @@ -89,3 +152,96 @@ async def test_switch(hass, zha_device_joined_restored, zigpy_device):

# test joining a new switch to the network and HA
await async_test_rejoin(hass, zigpy_device, [cluster], (1,))


async def async_test_zha_group_switch_entity(
hass, device_switch_1, device_switch_2, coordinator
):
"""Test the switch entity for a ZHA group."""
zha_gateway = get_zha_gateway(hass)
assert zha_gateway is not None
zha_gateway.coordinator_zha_device = coordinator
coordinator._zha_gateway = zha_gateway
device_switch_1._zha_gateway = zha_gateway
device_switch_2._zha_gateway = zha_gateway
member_ieee_addresses = [device_switch_1.ieee, device_switch_2.ieee]

# test creating a group with 2 members
zha_group = await zha_gateway.async_create_zigpy_group(
"Test Group", member_ieee_addresses
)
await hass.async_block_till_done()

assert zha_group is not None
assert zha_group.entity_domain == DOMAIN
assert len(zha_group.members) == 2
for member in zha_group.members:
assert member.ieee in member_ieee_addresses

entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group)
assert hass.states.get(entity_id) is not None

group_cluster_on_off = zha_group.endpoint[general.OnOff.cluster_id]
dev1_cluster_on_off = device_switch_1.endpoints[1].on_off
dev2_cluster_on_off = device_switch_2.endpoints[1].on_off

# test that the lights were created and that they are unavailable
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE

# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, zha_group.members)

# test that the lights were created and are off
assert hass.states.get(entity_id).state == STATE_OFF

# turn on from HA
with patch(
"zigpy.zcl.Cluster.request",
return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]),
):
# turn on via UI
await hass.services.async_call(
DOMAIN, "turn_on", {"entity_id": entity_id}, blocking=True
)
assert len(group_cluster_on_off.request.mock_calls) == 1
assert group_cluster_on_off.request.call_args == call(
False, ON, (), expect_reply=True, manufacturer=None, tsn=None
)
assert hass.states.get(entity_id).state == STATE_ON

# turn off from HA
with patch(
"zigpy.zcl.Cluster.request",
return_value=mock_coro([0x01, zcl_f.Status.SUCCESS]),
):
# turn off via UI
await hass.services.async_call(
DOMAIN, "turn_off", {"entity_id": entity_id}, blocking=True
)
assert len(group_cluster_on_off.request.mock_calls) == 1
assert group_cluster_on_off.request.call_args == call(
False, OFF, (), expect_reply=True, manufacturer=None, tsn=None
)
assert hass.states.get(entity_id).state == STATE_OFF

# test some of the group logic to make sure we key off states correctly
await dev1_cluster_on_off.on()
await dev2_cluster_on_off.on()

# test that group light is on
assert hass.states.get(entity_id).state == STATE_ON

await dev1_cluster_on_off.off()

# test that group light is still on
assert hass.states.get(entity_id).state == STATE_ON

await dev2_cluster_on_off.off()

# test that group light is now off
assert hass.states.get(entity_id).state == STATE_OFF

await dev1_cluster_on_off.on()

# test that group light is now back on
assert hass.states.get(entity_id).state == STATE_ON