Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 47 additions & 30 deletions homeassistant/components/zha/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from collections import Counter
from collections.abc import Callable
import logging
from typing import TYPE_CHECKING

from homeassistant import const as ha_const
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.helpers.typing import ConfigType

from . import const as zha_const, registries as zha_regs, typing as zha_typing
from . import const as zha_const, registries as zha_regs
from .. import ( # noqa: F401 pylint: disable=unused-import,
alarm_control_panel,
binary_sensor,
Expand All @@ -32,16 +35,23 @@
)
from .channels import base

if TYPE_CHECKING:
from ..entity import ZhaEntity
from .channels import ChannelPool
from .device import ZHADevice
from .gateway import ZHAGateway
from .group import ZHAGroup

_LOGGER = logging.getLogger(__name__)


@callback
async def async_add_entities(
_async_add_entities: Callable,
_async_add_entities: AddEntitiesCallback,
entities: list[
tuple[
zha_typing.ZhaEntityType,
tuple[str, zha_typing.ZhaDeviceType, list[zha_typing.ChannelType]],
type[ZhaEntity],
tuple[str, ZHADevice, list[base.ZigbeeChannel]],
]
],
update_before_add: bool = True,
Expand All @@ -50,33 +60,35 @@ async def async_add_entities(
if not entities:
return
to_add = [ent_cls.create_entity(*args) for ent_cls, args in entities]
to_add = [entity for entity in to_add if entity is not None]
_async_add_entities(to_add, update_before_add=update_before_add)
entities_to_add = [entity for entity in to_add if entity is not None]
_async_add_entities(entities_to_add, update_before_add=update_before_add)
entities.clear()


class ProbeEndpoint:
"""All discovered channels and entities of an endpoint."""

def __init__(self):
def __init__(self) -> None:
"""Initialize instance."""
self._device_configs = {}
self._device_configs: ConfigType = {}

@callback
def discover_entities(self, channel_pool: zha_typing.ChannelPoolType) -> None:
def discover_entities(self, channel_pool: ChannelPool) -> None:
"""Process an endpoint on a zigpy device."""
self.discover_by_device_type(channel_pool)
self.discover_multi_entities(channel_pool)
self.discover_by_cluster_id(channel_pool)
zha_regs.ZHA_ENTITIES.clean_up()

@callback
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:
def discover_by_device_type(self, channel_pool: ChannelPool) -> None:
"""Process an endpoint on a zigpy device."""

unique_id = channel_pool.unique_id

component = self._device_configs.get(unique_id, {}).get(ha_const.CONF_TYPE)
component: str | None = self._device_configs.get(unique_id, {}).get(
ha_const.CONF_TYPE
)
if component is None:
ep_profile_id = channel_pool.endpoint.profile_id
ep_device_type = channel_pool.endpoint.device_type
Expand All @@ -93,7 +105,7 @@ def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> N
channel_pool.async_new_entity(component, entity_class, unique_id, claimed)

@callback
def discover_by_cluster_id(self, channel_pool: zha_typing.ChannelPoolType) -> None:
def discover_by_cluster_id(self, channel_pool: ChannelPool) -> None:
"""Process an endpoint on a zigpy device."""

items = zha_regs.SINGLE_INPUT_CLUSTER_DEVICE_CLASS.items()
Expand Down Expand Up @@ -125,8 +137,8 @@ def discover_by_cluster_id(self, channel_pool: zha_typing.ChannelPoolType) -> No
@staticmethod
def probe_single_cluster(
component: str,
channel: zha_typing.ChannelType,
ep_channels: zha_typing.ChannelPoolType,
channel: base.ZigbeeChannel,
ep_channels: ChannelPool,
) -> None:
"""Probe specified cluster for specific component."""
if component is None or component not in zha_const.PLATFORMS:
Expand All @@ -142,9 +154,7 @@ def probe_single_cluster(
ep_channels.claim_channels(claimed)
ep_channels.async_new_entity(component, entity_class, unique_id, claimed)

def handle_on_off_output_cluster_exception(
self, ep_channels: zha_typing.ChannelPoolType
) -> None:
def handle_on_off_output_cluster_exception(self, ep_channels: ChannelPool) -> None:
"""Process output clusters of the endpoint."""

profile_id = ep_channels.endpoint.profile_id
Expand All @@ -167,7 +177,7 @@ def handle_on_off_output_cluster_exception(

@staticmethod
@callback
def discover_multi_entities(channel_pool: zha_typing.ChannelPoolType) -> None:
def discover_multi_entities(channel_pool: ChannelPool) -> None:
"""Process an endpoint on and discover multiple entities."""

ep_profile_id = channel_pool.endpoint.profile_id
Expand Down Expand Up @@ -209,18 +219,21 @@ def discover_multi_entities(channel_pool: zha_typing.ChannelPoolType) -> None:

def initialize(self, hass: HomeAssistant) -> None:
"""Update device overrides config."""
zha_config = hass.data[zha_const.DATA_ZHA].get(zha_const.DATA_ZHA_CONFIG, {})
zha_config: ConfigType = hass.data[zha_const.DATA_ZHA].get(
zha_const.DATA_ZHA_CONFIG, {}
)
if overrides := zha_config.get(zha_const.CONF_DEVICE_CONFIG):
self._device_configs.update(overrides)


class GroupProbe:
"""Determine the appropriate component for a group."""

def __init__(self):
_hass: HomeAssistant

def __init__(self) -> None:
"""Initialize instance."""
self._hass = None
self._unsubs = []
self._unsubs: list[Callable[[], None]] = []

def initialize(self, hass: HomeAssistant) -> None:
"""Initialize the group probe."""
Expand All @@ -231,7 +244,7 @@ def initialize(self, hass: HomeAssistant) -> None:
)
)

def cleanup(self):
def cleanup(self) -> None:
"""Clean up on when zha shuts down."""
for unsub in self._unsubs[:]:
unsub()
Expand All @@ -240,13 +253,15 @@ def cleanup(self):
@callback
def _reprobe_group(self, group_id: int) -> None:
"""Reprobe a group for entities after its members change."""
zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][
zha_const.DATA_ZHA_GATEWAY
]
if (zha_group := zha_gateway.groups.get(group_id)) is None:
return
self.discover_group_entities(zha_group)

@callback
def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None:
def discover_group_entities(self, group: ZHAGroup) -> None:
"""Process a group and create any entities that are needed."""
# only create a group entity if there are 2 or more members in a group
if len(group.members) < 2:
Expand All @@ -262,7 +277,9 @@ def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None:
if not entity_domains:
return

zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][
zha_const.DATA_ZHA_GATEWAY
]
for domain in entity_domains:
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain)
if entity_class is None:
Expand All @@ -281,12 +298,12 @@ def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None:
async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES)

@staticmethod
def determine_entity_domains(
hass: HomeAssistant, group: zha_typing.ZhaGroupType
) -> list[str]:
def determine_entity_domains(hass: HomeAssistant, group: ZHAGroup) -> list[str]:
"""Determine the entity domains for this group."""
entity_domains: list[str] = []
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
zha_gateway: ZHAGateway = hass.data[zha_const.DATA_ZHA][
zha_const.DATA_ZHA_GATEWAY
]
all_domain_occurrences = []
for member in group.members:
if member.device.is_coordinator:
Expand Down