Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions homeassistant/helpers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
from itertools import groupby
import logging
from typing import Any, Awaitable, Callable, Iterable, Optional, cast
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Optional, cast

import voluptuous as vol
from voluptuous.humanize import humanize_error
Expand Down Expand Up @@ -55,6 +56,8 @@ class CollectionChangeSet:
Awaitable[None],
]

ChangeSetListener = Callable[[Iterable[CollectionChangeSet]], Awaitable[None]]


class CollectionError(HomeAssistantError):
"""Base class for collection related errors."""
Expand Down Expand Up @@ -106,6 +109,7 @@ def __init__(self, logger: logging.Logger, id_manager: IDManager | None = None):
self.id_manager = id_manager or IDManager()
self.data: dict[str, dict] = {}
self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = []

self.id_manager.add_collection(self.data)

Expand All @@ -122,14 +126,26 @@ def async_add_listener(self, listener: ChangeListener) -> None:
"""
self.listeners.append(listener)

@callback
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
"""Add a listener for a full change set.

Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)

async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
await asyncio.gather(
*[
listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners
for change_set in change_sets
]
],
*[
change_set_listener(change_sets)
for change_set_listener in self.change_set_listeners
],
)


Expand Down Expand Up @@ -312,29 +328,55 @@ def sync_entity_lifecycle(
) -> None:
"""Map a collection to an entity component."""
entities = {}
ent_reg = entity_registry.async_get(hass)

async def _collection_changed(change_type: str, item_id: str, config: dict) -> None:
"""Handle a collection change."""
if change_type == CHANGE_ADDED:
entity = create_entity(config)
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return

if change_type == CHANGE_REMOVED:
ent_reg = await entity_registry.async_get_registry(hass)
ent_to_remove = ent_reg.async_get_entity_id(domain, platform, item_id)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
else:
await entities[item_id].async_remove(force_remove=True)
entities.pop(item_id)
return
async def _add_entity(change_set: CollectionChangeSet) -> Entity:
entities[change_set.item_id] = create_entity(change_set.item)
return entities[change_set.item_id]

# CHANGE_UPDATED
await entities[item_id].async_update_config(config) # type: ignore
async def _remove_entity(change_set: CollectionChangeSet) -> None:
ent_to_remove = ent_reg.async_get_entity_id(
domain, platform, change_set.item_id
)
if ent_to_remove is not None:
ent_reg.async_remove(ent_to_remove)
else:
await entities[change_set.item_id].async_remove(force_remove=True)
entities.pop(change_set.item_id)

async def _update_entity(change_set: CollectionChangeSet) -> None:
await entities[change_set.item_id].async_update_config(change_set.item) # type: ignore

_func_map: dict[
str, Callable[[CollectionChangeSet], Coroutine[Any, Any, Entity | None]]
] = {
CHANGE_ADDED: _add_entity,
CHANGE_REMOVED: _remove_entity,
CHANGE_UPDATED: _update_entity,
}

async def _collection_changed(change_sets: Iterable[CollectionChangeSet]) -> None:
"""Handle a collection change."""
# Create a new bucket every time we have a different change type
# to ensure operations happen in order. We only group
# the same change type.
for _, grouped in groupby(
change_sets, lambda change_set: change_set.change_type
):
new_entities = [
entity
for entity in await asyncio.gather(
*[
_func_map[change_set.change_type](change_set)
for change_set in grouped
]
)
if entity is not None
]
if new_entities:
await entity_component.async_add_entities(new_entities)

collection.async_add_listener(_collection_changed)
collection.async_add_change_set_listener(_collection_changed)


class StorageCollectionWebsocket:
Expand Down