Skip to content
29 changes: 28 additions & 1 deletion homeassistant/helpers/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ class Entity(
# Entry in the entity registry
registry_entry: er.RegistryEntry | None = None

# If the entity is removed from the entity registry
_removed_from_registry: bool = False

# The device entry for this entity
device_entry: dr.DeviceEntry | None = None

Expand Down Expand Up @@ -1361,6 +1364,17 @@ async def __async_remove_impl(self, force_remove: bool) -> None:
not force_remove
and self.registry_entry
and not self.registry_entry.disabled
# Check if entity is still in the entity registry
# by checking self._removed_from_registry
#
# Because self.registry_entry is unset in a task,
# its possible that the entity has been removed but
# the task has not yet been executed.
#
# self._removed_from_registry is set to True in a
# callback which does not have the same issue.
#
and not self._removed_from_registry
):
# Set the entity's state will to unavailable + ATTR_RESTORED: True
self.registry_entry.write_unavailable_state(self.hass)
Expand Down Expand Up @@ -1430,10 +1444,23 @@ async def async_internal_will_remove_from_hass(self) -> None:
if self.platform:
self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id)

async def _async_registry_updated(
@callback
def _async_registry_updated(
self, event: EventType[er.EventEntityRegistryUpdatedData]
) -> None:
"""Handle entity registry update."""
action = event.data["action"]
is_remove = action == "remove"
self._removed_from_registry = is_remove
if action == "update" or is_remove:
self.hass.async_create_task(
self._async_process_registry_update_or_remove(event)
)

async def _async_process_registry_update_or_remove(
self, event: EventType[er.EventEntityRegistryUpdatedData]
) -> None:
"""Handle entity registry update or remove."""
data = event.data
if data["action"] == "remove":
await self.async_removed_from_registry()
Expand Down
25 changes: 18 additions & 7 deletions homeassistant/helpers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def _async_track_state_change_event(
_async_dispatch_entity_id_event,
_async_state_change_filter,
action,
False,
)


Expand Down Expand Up @@ -378,8 +379,16 @@ def _async_track_event(
bool,
],
action: Callable[[EventType[_TypedDictT]], None],
run_immediately: bool,
) -> CALLBACK_TYPE:
"""Track an event by a specific key."""
"""Track an event by a specific key.

This function is intended for internal use only.

The dispatcher_callable, filter_callable, event_type, and run_immediately
must always be the same for the listener_key as the first call to this
function will set the listener_key in hass.data.
"""
if not keys:
return _remove_empty_listener

Expand All @@ -388,24 +397,22 @@ def _async_track_event(

hass_data = hass.data

callbacks: dict[
str, list[HassJob[[EventType[_TypedDictT]], Any]]
] | None = hass_data.get(callbacks_key)
if not callbacks:
callbacks: dict[str, list[HassJob[[EventType[_TypedDictT]], Any]]] | None
if not (callbacks := hass_data.get(callbacks_key)):
callbacks = hass_data[callbacks_key] = {}

if listeners_key not in hass_data:
hass_data[listeners_key] = hass.bus.async_listen(
event_type,
ft.partial(dispatcher_callable, hass, callbacks),
event_filter=ft.partial(filter_callable, hass, callbacks),
run_immediately=run_immediately,
Comment thread
bdraco marked this conversation as resolved.
)

job = HassJob(action, f"track {event_type} event {keys}")

for key in keys:
callback_list = callbacks.get(key)
if callback_list:
if callback_list := callbacks.get(key):
callback_list.append(job)
else:
callbacks[key] = [job]
Expand Down Expand Up @@ -473,6 +480,7 @@ def async_track_entity_registry_updated_event(
_async_dispatch_old_entity_id_or_entity_id_event,
_async_entity_registry_updated_filter,
action,
True,
)


Expand Down Expand Up @@ -529,6 +537,7 @@ def async_track_device_registry_updated_event(
_async_dispatch_device_id_event,
_async_device_registry_updated_filter,
action,
True,
)


Expand Down Expand Up @@ -590,6 +599,7 @@ def _async_track_state_added_domain(
_async_dispatch_domain_event,
_async_domain_added_filter,
action,
False,
)


Expand Down Expand Up @@ -622,6 +632,7 @@ def async_track_state_removed_domain(
_async_dispatch_domain_event,
_async_domain_removed_filter,
action,
False,
)


Expand Down
88 changes: 88 additions & 0 deletions tests/helpers/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,3 +2468,91 @@ class MockEntityFeatures(IntFlag):
"is using deprecated supported features values which will be removed"
not in caplog.text
)


async def test_remove_entity_registry(
Comment thread
bdraco marked this conversation as resolved.
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test removing an entity from the registry."""
result = []

entry = entity_registry.async_get_or_create(
"test", "test_platform", "5678", suggested_object_id="test"
)
assert entry.entity_id == "test.test"

class MockEntity(entity.Entity):
_attr_unique_id = "5678"

def __init__(self) -> None:
self.added_calls = []
self.remove_calls = []

async def async_added_to_hass(self):
self.added_calls.append(None)
self.async_on_remove(lambda: result.append(1))

async def async_will_remove_from_hass(self):
self.remove_calls.append(None)

platform = MockEntityPlatform(hass, domain="test")
ent = MockEntity()
await platform.async_add_entities([ent])
assert hass.states.get("test.test").state == STATE_UNKNOWN
assert len(ent.added_calls) == 1

entry = entity_registry.async_remove(entry.entity_id)
await hass.async_block_till_done()

assert len(result) == 1
assert len(ent.added_calls) == 1
assert len(ent.remove_calls) == 1

assert hass.states.get("test.test") is None


async def test_reset_right_after_remove_entity_registry(
Comment thread
bdraco marked this conversation as resolved.
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test resetting the platform right after removing an entity from the registry.

A reset commonly happens during a reload.
"""
result = []

entry = entity_registry.async_get_or_create(
"test", "test_platform", "5678", suggested_object_id="test"
)
assert entry.entity_id == "test.test"

class MockEntity(entity.Entity):
_attr_unique_id = "5678"

def __init__(self) -> None:
self.added_calls = []
self.remove_calls = []

async def async_added_to_hass(self):
self.added_calls.append(None)
self.async_on_remove(lambda: result.append(1))

async def async_will_remove_from_hass(self):
self.remove_calls.append(None)

platform = MockEntityPlatform(hass, domain="test")
ent = MockEntity()
await platform.async_add_entities([ent])
assert hass.states.get("test.test").state == STATE_UNKNOWN
assert len(ent.added_calls) == 1

entry = entity_registry.async_remove(entry.entity_id)

# Reset the platform immediately after removing the entity from the registry
await platform.async_reset()
await hass.async_block_till_done()

assert len(result) == 1
assert len(ent.added_calls) == 1
assert len(ent.remove_calls) == 1

assert hass.states.get("test.test") is None