Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@
block_async_io.enable()
fix_threading_exception_logging()

# pylint: disable=invalid-name
T = TypeVar("T")
# pylint: disable=invalid-name
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
CALLBACK_TYPE = Callable[[], None]
# pylint: enable=invalid-name
Expand Down
28 changes: 7 additions & 21 deletions homeassistant/helpers/device_registry.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Provide a way to connect entities belonging to one device."""
from asyncio import Event
from collections import OrderedDict
import logging
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional
import uuid

import attr

from homeassistant.core import callback
from homeassistant.loader import bind_hass

from .singleton import singleton
from .typing import HomeAssistantType

# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
Expand Down Expand Up @@ -356,26 +356,12 @@ def async_clear_area_id(self, area_id: str) -> None:


@bind_hass
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
"""Return device registry instance."""
reg_or_evt = hass.data.get(DATA_REGISTRY)

if not reg_or_evt:
evt = hass.data[DATA_REGISTRY] = Event()

reg = DeviceRegistry(hass)
await reg.async_load()

hass.data[DATA_REGISTRY] = reg
evt.set()
return reg

if isinstance(reg_or_evt, Event):
evt = reg_or_evt
await evt.wait()
return cast(DeviceRegistry, hass.data.get(DATA_REGISTRY))

return cast(DeviceRegistry, reg_or_evt)
"""Create entity registry."""
reg = DeviceRegistry(hass)
await reg.async_load()
return reg


@callback
Expand Down
28 changes: 7 additions & 21 deletions homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
registered. Registering a new entity while a timer is in progress resets the
timer.
"""
import asyncio
from collections import OrderedDict
import logging
from typing import (
Expand Down Expand Up @@ -39,6 +38,7 @@
from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml

from .singleton import singleton
from .typing import HomeAssistantType

if TYPE_CHECKING:
Expand Down Expand Up @@ -492,26 +492,12 @@ def async_clear_config_entry(self, config_entry: str) -> None:


@bind_hass
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
"""Return entity registry instance."""
reg_or_evt = hass.data.get(DATA_REGISTRY)

if not reg_or_evt:
evt = hass.data[DATA_REGISTRY] = asyncio.Event()

reg = EntityRegistry(hass)
await reg.async_load()

hass.data[DATA_REGISTRY] = reg
evt.set()
return reg

if isinstance(reg_or_evt, asyncio.Event):
evt = reg_or_evt
await evt.wait()
return cast(EntityRegistry, hass.data.get(DATA_REGISTRY))

return cast(EntityRegistry, reg_or_evt)
"""Create entity registry."""
reg = EntityRegistry(hass)
await reg.async_load()
return reg


@callback
Expand Down Expand Up @@ -621,4 +607,4 @@ async def async_migrate_entries(
updates = entry_callback(entry)

if updates is not None:
ent_reg.async_update_entity(entry.entity_id, **updates) # type: ignore
ent_reg.async_update_entity(entry.entity_id, **updates)
44 changes: 44 additions & 0 deletions homeassistant/helpers/singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Helper to help coordinating calls."""
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar, cast

from homeassistant.core import HomeAssistant

T = TypeVar("T")

FUNC = Callable[[HomeAssistant], Awaitable[T]]


def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
"""Decorate a function that should be called once per instance.

Result will be cached and simultaneous calls will be handled.
"""

def wrapper(func: FUNC) -> FUNC:
"""Wrap a function with caching logic."""

@functools.wraps(func)
async def wrapped(hass: HomeAssistant) -> T:
obj_or_evt = hass.data.get(data_key)

if not obj_or_evt:
evt = hass.data[data_key] = asyncio.Event()

result = await func(hass)

hass.data[data_key] = result
evt.set()
return cast(T, result)

if isinstance(obj_or_evt, asyncio.Event):
evt = obj_or_evt
await evt.wait()
return cast(T, hass.data.get(data_key))

return cast(T, obj_or_evt)

return wrapped

return wrapper
2 changes: 1 addition & 1 deletion pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ persistent=no
extension-pkg-whitelist=ciso8601

[BASIC]
good-names=id,i,j,k,ex,Run,_,fp
good-names=id,i,j,k,ex,Run,_,fp,T

[MESSAGES CONTROL]
# Reasons disabled:
Expand Down