Skip to content
175 changes: 34 additions & 141 deletions homeassistant/components/zha/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from homeassistant.components import websocket_api
from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import async_get_registry
from homeassistant.helpers.dispatcher import async_dispatcher_connect

from .core.const import (
Expand Down Expand Up @@ -53,11 +52,7 @@
WARNING_DEVICE_STROBE_HIGH,
WARNING_DEVICE_STROBE_YES,
)
from .core.helpers import (
async_get_device_info,
async_is_bindable_target,
get_matched_clusters,
)
from .core.helpers import async_is_bindable_target, get_matched_clusters

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -212,13 +207,9 @@ def async_cleanup() -> None:
async def websocket_get_devices(hass, connection, msg):
"""Get ZHA devices."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)

devices = []
for device in zha_gateway.devices.values():
devices.append(
async_get_device_info(hass, device, ha_device_registry=ha_device_registry)
)
devices = [device.async_get_info() for device in zha_gateway.devices.values()]

connection.send_result(msg[ID], devices)


Expand All @@ -228,16 +219,13 @@ async def websocket_get_devices(hass, connection, msg):
async def websocket_get_groupable_devices(hass, connection, msg):
"""Get ZHA devices that can be grouped."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)

devices = []
for device in zha_gateway.devices.values():
if device.is_groupable:
devices.append(
async_get_device_info(
hass, device, ha_device_registry=ha_device_registry
)
)

devices = [
device.async_get_info()
for device in zha_gateway.devices.values()
if device.is_groupable or device.is_coordinator
]

connection.send_result(msg[ID], devices)


Expand All @@ -246,7 +234,8 @@ async def websocket_get_groupable_devices(hass, connection, msg):
@websocket_api.websocket_command({vol.Required(TYPE): "zha/groups"})
async def websocket_get_groups(hass, connection, msg):
"""Get ZHA groups."""
groups = await get_groups(hass)
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
groups = [group.async_get_info() for group in zha_gateway.groups.values()]
connection.send_result(msg[ID], groups)


Expand All @@ -258,13 +247,10 @@ async def websocket_get_groups(hass, connection, msg):
async def websocket_get_device(hass, connection, msg):
"""Get ZHA devices."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)
ieee = msg[ATTR_IEEE]
device = None
if ieee in zha_gateway.devices:
device = async_get_device_info(
hass, zha_gateway.devices[ieee], ha_device_registry=ha_device_registry
)
device = zha_gateway.devices[ieee].async_get_info()
if not device:
connection.send_message(
websocket_api.error_message(
Expand All @@ -283,17 +269,11 @@ async def websocket_get_device(hass, connection, msg):
async def websocket_get_group(hass, connection, msg):
"""Get ZHA group."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)
group_id = msg[GROUP_ID]
group = None

if group_id in zha_gateway.application_controller.groups:
group = async_get_group_info(
hass,
zha_gateway,
zha_gateway.application_controller.groups[group_id],
ha_device_registry,
)
if group_id in zha_gateway.groups:
group = zha_gateway.groups.get(group_id).async_get_info()
if not group:
connection.send_message(
websocket_api.error_message(
Expand All @@ -316,28 +296,10 @@ async def websocket_get_group(hass, connection, msg):
async def websocket_add_group(hass, connection, msg):
"""Add a new ZHA group."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)
group_name = msg[GROUP_NAME]
zigpy_group = async_get_group_by_name(zha_gateway, group_name)
ret_group = None
members = msg.get(ATTR_MEMBERS)
# we start with one to fill any gaps from a user removing existing groups
group_id = 1
while group_id in zha_gateway.application_controller.groups:
group_id += 1

# guard against group already existing
if zigpy_group is None:
zigpy_group = zha_gateway.application_controller.groups.add_group(
group_id, group_name
)
if members is not None:
tasks = []
for ieee in members:
tasks.append(zha_gateway.devices[ieee].async_add_to_group(group_id))
await asyncio.gather(*tasks)
ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry)
connection.send_result(msg[ID], ret_group)
group = await zha_gateway.async_create_zigpy_group(group_name, members)
connection.send_result(msg[ID], group.async_get_info())


@websocket_api.require_admin
Expand All @@ -351,17 +313,16 @@ async def websocket_add_group(hass, connection, msg):
async def websocket_remove_groups(hass, connection, msg):
"""Remove the specified ZHA groups."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
groups = zha_gateway.application_controller.groups
group_ids = msg[GROUP_IDS]

if len(group_ids) > 1:
tasks = []
for group_id in group_ids:
tasks.append(remove_group(groups[group_id], zha_gateway))
tasks.append(zha_gateway.async_remove_zigpy_group(group_id))
await asyncio.gather(*tasks)
else:
await remove_group(groups[group_ids[0]], zha_gateway)
ret_groups = await get_groups(hass)
await zha_gateway.async_remove_zigpy_group(group_ids[0])
ret_groups = [group.async_get_info() for group in zha_gateway.groups.values()]
connection.send_result(msg[ID], ret_groups)


Expand All @@ -377,25 +338,21 @@ async def websocket_remove_groups(hass, connection, msg):
async def websocket_add_group_members(hass, connection, msg):
"""Add members to a ZHA group."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)
group_id = msg[GROUP_ID]
members = msg[ATTR_MEMBERS]
zigpy_group = None
zha_group = None

if group_id in zha_gateway.application_controller.groups:
zigpy_group = zha_gateway.application_controller.groups[group_id]
tasks = []
for ieee in members:
tasks.append(zha_gateway.devices[ieee].async_add_to_group(group_id))
await asyncio.gather(*tasks)
if not zigpy_group:
if group_id in zha_gateway.groups:
zha_group = zha_gateway.groups.get(group_id)
await zha_group.async_add_members(members)
if not zha_group:
connection.send_message(
websocket_api.error_message(
msg[ID], websocket_api.const.ERR_NOT_FOUND, "ZHA Group not found"
)
)
return
ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry)
ret_group = zha_group.async_get_info()
connection.send_result(msg[ID], ret_group)


Expand All @@ -411,88 +368,24 @@ async def websocket_add_group_members(hass, connection, msg):
async def websocket_remove_group_members(hass, connection, msg):
"""Remove members from a ZHA group."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)
group_id = msg[GROUP_ID]
members = msg[ATTR_MEMBERS]
zigpy_group = None
zha_group = None

if group_id in zha_gateway.application_controller.groups:
zigpy_group = zha_gateway.application_controller.groups[group_id]
tasks = []
for ieee in members:
tasks.append(zha_gateway.devices[ieee].async_remove_from_group(group_id))
await asyncio.gather(*tasks)
if not zigpy_group:
if group_id in zha_gateway.groups:
zha_group = zha_gateway.groups.get(group_id)
await zha_group.async_remove_members(members)
if not zha_group:
connection.send_message(
websocket_api.error_message(
msg[ID], websocket_api.const.ERR_NOT_FOUND, "ZHA Group not found"
)
)
return
ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry)
ret_group = zha_group.async_get_info()
connection.send_result(msg[ID], ret_group)


async def get_groups(hass,):
"""Get ZHA Groups."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ha_device_registry = await async_get_registry(hass)

groups = []
for group in zha_gateway.application_controller.groups.values():
groups.append(
async_get_group_info(hass, zha_gateway, group, ha_device_registry)
)
return groups


async def remove_group(group, zha_gateway):
"""Remove ZHA Group."""
if group.members:
tasks = []
for member_ieee in group.members.keys():
if member_ieee[0] in zha_gateway.devices:
tasks.append(
zha_gateway.devices[member_ieee[0]].async_remove_from_group(
group.group_id
)
)
if tasks:
await asyncio.gather(*tasks)
else:
# we have members but none are tracked by ZHA for whatever reason
zha_gateway.application_controller.groups.pop(group.group_id)
else:
zha_gateway.application_controller.groups.pop(group.group_id)


@callback
def async_get_group_info(hass, zha_gateway, group, ha_device_registry):
"""Get ZHA group."""
ret_group = {}
ret_group["group_id"] = group.group_id
ret_group["name"] = group.name
ret_group["members"] = [
async_get_device_info(
hass,
zha_gateway.get_device(member_ieee[0]),
ha_device_registry=ha_device_registry,
)
for member_ieee in group.members.keys()
if member_ieee[0] in zha_gateway.devices
]
return ret_group


@callback
def async_get_group_by_name(zha_gateway, group_name):
"""Get ZHA group by name."""
for group in zha_gateway.application_controller.groups.values():
if group.name == group_name:
return group
return None


@websocket_api.require_admin
@websocket_api.async_response
@websocket_api.websocket_command(
Expand Down Expand Up @@ -712,9 +605,9 @@ async def websocket_get_bindable_devices(hass, connection, msg):
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
source_ieee = msg[ATTR_IEEE]
source_device = zha_gateway.get_device(source_ieee)
ha_device_registry = await async_get_registry(hass)

devices = [
async_get_device_info(hass, device, ha_device_registry=ha_device_registry)
device.async_get_info()
for device in zha_gateway.devices.values()
if async_is_bindable_target(source_device, device)
]
Expand Down
13 changes: 9 additions & 4 deletions homeassistant/components/zha/core/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,18 @@ def list(cls):
WARNING_DEVICE_SQUAWK_MODE_DISARMED = 1

ZHA_DISCOVERY_NEW = "zha_discovery_new_{}"
ZHA_GW_MSG_RAW_INIT = "raw_device_initialized"
ZHA_GW_MSG = "zha_gateway_message"
ZHA_GW_MSG_DEVICE_REMOVED = "device_removed"
ZHA_GW_MSG_DEVICE_INFO = "device_info"
ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized"
ZHA_GW_MSG_DEVICE_INFO = "device_info"
ZHA_GW_MSG_DEVICE_JOINED = "device_joined"
ZHA_GW_MSG_LOG_OUTPUT = "log_output"
ZHA_GW_MSG_DEVICE_REMOVED = "device_removed"
ZHA_GW_MSG_GROUP_ADDED = "group_added"
ZHA_GW_MSG_GROUP_INFO = "group_info"
ZHA_GW_MSG_GROUP_MEMBER_ADDED = "group_member_added"
ZHA_GW_MSG_GROUP_MEMBER_REMOVED = "group_member_removed"
ZHA_GW_MSG_GROUP_REMOVED = "group_removed"
ZHA_GW_MSG_LOG_ENTRY = "log_entry"
ZHA_GW_MSG_LOG_OUTPUT = "log_output"
ZHA_GW_MSG_RAW_INIT = "raw_device_initialized"
ZHA_GW_RADIO = "radio"
ZHA_GW_RADIO_DESCRIPTION = "radio_description"
29 changes: 29 additions & 0 deletions homeassistant/components/zha/core/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,18 @@ def __init__(self, hass, zigpy_device, zha_gateway):
self._available_check = async_track_time_interval(
self.hass, self._check_available, _UPDATE_ALIVE_INTERVAL
)
self._ha_device_id = None
self.status = DeviceStatus.CREATED

@property
def device_id(self):
"""Return the HA device registry device id."""
return self._ha_device_id

def set_device_id(self, device_id):
"""Set the HA device registry device id."""
self._ha_device_id = device_id

@property
def name(self):
"""Return device name."""
Expand Down Expand Up @@ -406,6 +416,25 @@ def async_update_last_seen(self, last_seen):
"""Set last seen on the zigpy device."""
self._zigpy_device.last_seen = last_seen

@callback
def async_get_info(self):
"""Get ZHA device information."""
device_info = {}
device_info.update(self.device_info)
device_info["entities"] = [
{
"entity_id": entity_ref.reference_id,
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
}
for entity_ref in self.gateway.device_registry[self.ieee]
]
reg_device = self.gateway.ha_device_registry.async_get(self.device_id)
if reg_device is not None:
device_info["user_given_name"] = reg_device.name_by_user
device_info["device_reg_id"] = reg_device.id
device_info["area_id"] = reg_device.area_id
return device_info

@callback
def async_get_clusters(self):
"""Get all clusters for this device."""
Expand Down
Loading