Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ def _async_add_entity(self, entity, update_before_add, component_entities,
entry = registry.async_get_or_create(
self.domain, self.platform_name, entity.unique_id,
suggested_object_id=suggested_object_id)

if entry.disabled:
self.logger.info(
"Not adding entity %s because it's disabled",
entry.name or entity.name or
'"{} {}"'.format(self.platform_name, entity.unique_id))
return

entity.entity_id = entry.entity_id
entity.registry_name = entry.name

Expand Down
14 changes: 13 additions & 1 deletion homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__)

DISABLED_HASS = 'hass'
DISABLED_USER = 'user'


@attr.s(slots=True, frozen=True)
class RegistryEntry:
Expand All @@ -35,12 +38,20 @@ class RegistryEntry:
unique_id = attr.ib(type=str)
platform = attr.ib(type=str)
name = attr.ib(type=str, default=None)
disabled_by = attr.ib(
type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
domain = attr.ib(type=str, default=None, init=False, repr=False)

def __attrs_post_init__(self):
"""Computed properties."""
object.__setattr__(self, "domain", split_entity_id(self.entity_id)[0])

@property
def disabled(self):
"""Return if entry is disabled."""
return self.disabled_by is not None


class EntityRegistry:
"""Class to hold a registry of entities."""
Expand Down Expand Up @@ -116,7 +127,8 @@ def _async_load(self):
entity_id=entity_id,
unique_id=info['unique_id'],
platform=info['platform'],
name=info.get('name')
name=info.get('name'),
disabled_by=info.get('disabled_by')
)

self.entities = entities
Expand Down
30 changes: 25 additions & 5 deletions tests/helpers/test_entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
PLATFORM = 'test_platform'


class MockEntityPlatform(entity_platform.EntityPlatform):
"""Mock class with some mock defaults."""

def __init__(
self, *, hass,
self, hass,
logger=None,
domain='test',
platform_name='test_platform',
domain=DOMAIN,
platform_name=PLATFORM,
scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None,
Expand Down Expand Up @@ -486,7 +487,26 @@ def test_overriding_name_from_registry(hass):
def test_registry_respect_entity_namespace(hass):
"""Test that the registry respects entity namespace."""
mock_registry(hass)
platform = MockEntityPlatform(hass=hass, entity_namespace='ns')
platform = MockEntityPlatform(hass, entity_namespace='ns')
entity = MockEntity(unique_id='1234', name='Device Name')
yield from platform.async_add_entities([entity])
assert entity.entity_id == 'test.ns_device_name'
assert entity.entity_id == 'test_domain.ns_device_name'


@asyncio.coroutine
def test_registry_respect_entity_disabled(hass):
"""Test that the registry respects entity disabled."""
mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
disabled_by=entity_registry.DISABLED_USER
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
yield from platform.async_add_entities([entity])
assert entity.entity_id is None
assert hass.states.async_entity_ids() == []
18 changes: 18 additions & 0 deletions tests/helpers/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ def test_loading_extra_values(hass):
test.no_name:
platform: super_platform
unique_id: without-name
test.disabled_user:
platform: super_platform
unique_id: disabled-user
disabled_by: user
test.disabled_hass:
platform: super_platform
unique_id: disabled-hass
disabled_by: hass
"""

registry = entity_registry.EntityRegistry(hass)
Expand All @@ -162,3 +170,13 @@ def test_loading_extra_values(hass):
'test', 'super_platform', 'without-name')
assert entry_with_name.name == 'registry override'
assert entry_without_name.name is None
assert not entry_with_name.disabled

entry_disabled_hass = registry.async_get_or_create(
'test', 'super_platform', 'disabled-hass')
entry_disabled_user = registry.async_get_or_create(
'test', 'super_platform', 'disabled-user')
assert entry_disabled_hass.disabled
assert entry_disabled_hass.disabled_by == entity_registry.DISABLED_HASS
assert entry_disabled_user.disabled
assert entry_disabled_user.disabled_by == entity_registry.DISABLED_USER