Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions homeassistant/helpers/collection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast

import voluptuous as vol
from voluptuous.humanize import humanize_error
Expand All @@ -26,6 +27,20 @@
CHANGE_REMOVED = "removed"


@dataclass
class CollectionChangeSet:
"""Class to represent a change set.

change_type: One of CHANGE_*
item_id: The id of the item
item: The item
"""

change_type: str
item_id: str
item: Any


ChangeListener = Callable[
[
# Change type
Expand Down Expand Up @@ -105,11 +120,14 @@ def async_add_listener(self, listener: ChangeListener) -> None:
"""
self.listeners.append(listener)

async def notify_change(self, change_type: str, item_id: str, item: dict) -> None:
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item)
await asyncio.gather(
*[listener(change_type, item_id, item) for listener in self.listeners]
*[
listener(change_set.change_type, change_set.item_id, change_set.item)
for listener in self.listeners
for change_set in change_sets
]
)


Expand All @@ -118,9 +136,10 @@ class YamlCollection(ObservableCollection):

async def async_load(self, data: List[dict]) -> None:
"""Load the YAML collection. Overrides existing data."""

old_ids = set(self.data)

tasks = []
change_sets = []

for item in data:
item_id = item[CONF_ID]
Expand All @@ -135,15 +154,15 @@ async def async_load(self, data: List[dict]) -> None:
event = CHANGE_ADDED

self.data[item_id] = item
tasks.append(self.notify_change(event, item_id, item))
change_sets.append(CollectionChangeSet(event, item_id, item))

for item_id in old_ids:
tasks.append(
self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id))
change_sets.append(
CollectionChangeSet(CHANGE_REMOVED, item_id, self.data.pop(item_id))
)

if tasks:
await asyncio.gather(*tasks)
if change_sets:
await self.notify_changes(change_sets)


class StorageCollection(ObservableCollection):
Expand Down Expand Up @@ -178,9 +197,9 @@ async def async_load(self) -> None:
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item

await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)
for item in raw_storage["items"]
]
)
Expand All @@ -204,7 +223,9 @@ async def async_create_item(self, data: dict) -> dict:
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item
self._async_schedule_save()
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
await self.notify_changes(
[CollectionChangeSet(CHANGE_ADDED, item[CONF_ID], item)]
)
return item

async def async_update_item(self, item_id: str, updates: dict) -> dict:
Expand All @@ -222,7 +243,9 @@ async def async_update_item(self, item_id: str, updates: dict) -> dict:
self.data[item_id] = updated
self._async_schedule_save()

await self.notify_change(CHANGE_UPDATED, item_id, updated)
await self.notify_changes(
[CollectionChangeSet(CHANGE_UPDATED, item_id, updated)]
)

return self.data[item_id]

Expand All @@ -234,7 +257,7 @@ async def async_delete_item(self, item_id: str) -> None:
item = self.data.pop(item_id)
self._async_schedule_save()

await self.notify_change(CHANGE_REMOVED, item_id, item)
await self.notify_changes([CollectionChangeSet(CHANGE_REMOVED, item_id, item)])

@callback
def _async_schedule_save(self) -> None:
Expand All @@ -254,9 +277,9 @@ class IDLessCollection(ObservableCollection):

async def async_load(self, data: List[dict]) -> None:
"""Load the collection. Overrides existing data."""
await asyncio.gather(
*[
self.notify_change(CHANGE_REMOVED, item_id, item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_REMOVED, item_id, item)
for item_id, item in list(self.data.items())
]
)
Expand All @@ -269,9 +292,9 @@ async def async_load(self, data: List[dict]) -> None:

self.data[item_id] = item

await asyncio.gather(
*[
self.notify_change(CHANGE_ADDED, item_id, item)
await self.notify_changes(
[
CollectionChangeSet(CHANGE_ADDED, item_id, item)
for item_id, item in self.data.items()
]
)
Expand Down
32 changes: 22 additions & 10 deletions tests/helpers/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ async def test_observable_collection():
assert coll.async_items() == [1]

changes = track_changes(coll)
await coll.notify_change("mock_type", "mock_id", {"mock": "item"})
await coll.notify_changes(
[collection.CollectionChangeSet("mock_type", "mock_id", {"mock": "item"})]
)
assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})

Expand Down Expand Up @@ -226,25 +228,35 @@ async def test_attach_entity_component_collection(hass):
coll = collection.ObservableCollection(_LOGGER)
collection.attach_entity_component_collection(ent_comp, coll, MockEntity)

await coll.notify_change(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
],
)

assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial"

await coll.notify_change(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
await coll.notify_changes(
[
collection.CollectionChangeSet(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
],
)

assert hass.states.get("test.mock_1").name == "Mock 1 updated"
assert hass.states.get("test.mock_1").state == "second"

await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None)
await coll.notify_changes(
[collection.CollectionChangeSet(collection.CHANGE_REMOVED, "mock_id", None)],
)

assert hass.states.get("test.mock_1") is None

Expand Down