Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ async def async_get_user_by_credentials(

return None

async def async_create_system_user(self, name: str) -> models.User:
async def async_create_system_user(
self, name: str,
group_ids: Optional[List[str]] = None) -> models.User:
"""Create a system user."""
user = await self._store.async_create_user(
name=name,
system_generated=True,
is_active=True,
group_ids=[],
group_ids=group_ids or [],
)

self.hass.bus.async_fire(EVENT_USER_ADDED, {
Expand Down Expand Up @@ -217,6 +219,17 @@ async def async_remove_user(self, user: models.User) -> None:
'user_id': user.id
})

async def async_update_user(self, user: models.User,
name: Optional[str] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
kwargs = {} # type: Dict[str,Any]
if name is not None:
kwargs['name'] = name
if group_ids is not None:
kwargs['group_ids'] = group_ids
await self._store.async_update_user(user, **kwargs)

async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
await self._store.async_activate_user(user)
Expand Down
27 changes: 27 additions & 0 deletions homeassistant/auth/auth_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,33 @@ async def async_remove_user(self, user: models.User) -> None:
self._users.pop(user.id)
self._async_schedule_save()

async def async_update_user(
self, user: models.User, name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
assert self._groups is not None

if group_ids is not None:
groups = []
for grid in group_ids:
group = self._groups.get(grid)
if group is None:
raise ValueError("Invalid group specified.")
groups.append(group)

user.groups = groups
user.invalidate_permission_cache()

for attr_name, value in (
('name', name),
('is_active', is_active),
):
if value is not None:
setattr(user, attr_name, value)

self._async_schedule_save()

async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = True
Expand Down
16 changes: 15 additions & 1 deletion homeassistant/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.util import dt as dt_util

from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN
from .util import generate_secret

TOKEN_TYPE_NORMAL = 'normal'
Expand Down Expand Up @@ -48,7 +49,7 @@ class User:
) # type: Dict[str, RefreshToken]

_permissions = attr.ib(
type=perm_mdl.PolicyPermissions,
type=Optional[perm_mdl.PolicyPermissions],
init=False,
cmp=False,
default=None,
Expand All @@ -69,6 +70,19 @@ def permissions(self) -> perm_mdl.AbstractPermissions:

return self._permissions

@property
def is_admin(self) -> bool:
"""Return if user is part of the admin group."""
if self.is_owner:
return True

return self.is_active and any(
gr.id == GROUP_ID_ADMIN for gr in self.groups)

def invalidate_permission_cache(self) -> None:
"""Invalidate permission cache."""
self._permissions = None


@attr.s(slots=True)
class RefreshToken:
Expand Down
61 changes: 19 additions & 42 deletions homeassistant/auth/permissions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

import voluptuous as vol

from homeassistant.core import State

from .const import CAT_ENTITIES
from .types import CategoryType, PolicyType
from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa

Expand All @@ -22,49 +20,32 @@
class AbstractPermissions:
"""Default permissions class."""

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
raise NotImplementedError
_cached_entity_func = None

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
raise NotImplementedError

def check_entity(self, entity_id: str, key: str) -> bool:
"""Check if we can access entity."""
entity_func = self._cached_entity_func

if entity_func is None:
entity_func = self._cached_entity_func = self._entity_func()

return entity_func(entity_id, key)


class PolicyPermissions(AbstractPermissions):
"""Handle permissions."""

def __init__(self, policy: PolicyType) -> None:
"""Initialize the permission class."""
self._policy = policy
self._compiled = {} # type: Dict[str, Callable[..., bool]]

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
return func(entity_id, (key,))

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
keys = ('read',)
return [entity for entity in states if func(entity.entity_id, keys)]

def _policy_func(self, category: str,
compile_func: Callable[[CategoryType], Callable]) \
-> Callable[..., bool]:
"""Get a policy function."""
func = self._compiled.get(category)

if func:
return func

func = self._compiled[category] = compile_func(
self._policy.get(category))

_LOGGER.debug("Compiled %s func: %s", category, func)

return func
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES))

def __eq__(self, other: Any) -> bool:
"""Equals check."""
Expand All @@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions):

# pylint: disable=no-self-use

def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
return True

def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
return states
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return lambda entity_id, key: True


OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name
40 changes: 20 additions & 20 deletions homeassistant/auth/permissions/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,28 @@
}))


def _entity_allowed(schema: ValueType, keys: Tuple[str]) \
def _entity_allowed(schema: ValueType, key: str) \
-> Union[bool, None]:
"""Test if an entity is allowed based on the keys."""
if schema is None or isinstance(schema, bool):
return schema
assert isinstance(schema, dict)
return schema.get(keys[0])
return schema.get(key)


def compile_entities(policy: CategoryType) \
-> Callable[[str, Tuple[str]], bool]:
-> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy."""
# None, Empty Dict, False
if not policy:
def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all."""
return False

return apply_policy_deny_all

if policy is True:
def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True

Expand All @@ -61,7 +61,7 @@ def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL)

funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]]
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]

# The order of these functions matter. The more precise are at the top.
# If a function returns None, they cannot handle it.
Expand All @@ -70,55 +70,55 @@ def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
# Setting entity_ids to a boolean is final decision for permissions
# So return right away.
if isinstance(entity_ids, bool):
def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool:
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
"""Test if allowed entity_id."""
return entity_ids # type: ignore

return allowed_entity_id_bool

if entity_ids is not None:
def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \
def allowed_entity_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed entity_id."""
return _entity_allowed(
entity_ids.get(entity_id), keys) # type: ignore
entity_ids.get(entity_id), key) # type: ignore

funcs.append(allowed_entity_id_dict)

if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \
def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return domains

funcs.append(allowed_domain_bool)

elif domains is not None:
def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \
def allowed_domain_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
domain = entity_id.split(".", 1)[0]
return _entity_allowed(domains.get(domain), keys) # type: ignore
return _entity_allowed(domains.get(domain), key) # type: ignore

funcs.append(allowed_domain_dict)

if isinstance(all_entities, bool):
def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \
def allowed_all_entities_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return all_entities
funcs.append(allowed_all_entities_bool)

elif all_entities is not None:
def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \
def allowed_all_entities_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return _entity_allowed(all_entities, keys)
return _entity_allowed(all_entities, key)
funcs.append(allowed_all_entities_dict)

# Can happen if no valid subcategories specified
if not funcs:
def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
"""Decline all."""
return False

Expand All @@ -128,16 +128,16 @@ def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool:
func = funcs[0]

@wraps(func)
def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_func(entity_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(entity_id, keys) is True
return func(entity_id, key) is True

return apply_policy_func

def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_funcs(entity_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(entity_id, keys)
result = func(entity_id, key)
if result is not None:
return result
return False
Expand Down
Loading