Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 64 additions & 33 deletions homeassistant/auth.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Provide an authentication layer for Home Assistant."""
import asyncio
import binascii
from collections import OrderedDict
from datetime import datetime, timedelta
import os
import importlib
import logging
import os
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta

import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error

from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util.decorator import Registry
from homeassistant.core import callback
from homeassistant.util import dt as dt_util

from homeassistant.util.decorator import Registry

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -337,6 +336,16 @@ async def async_create_client(self, name, *, redirect_uris=None,
return await self._store.async_create_client(
name, redirect_uris, no_secret)

async def async_get_or_create_client(self, name, *, redirect_uris=None,
no_secret=False):
"""Find a client, if not exists, create a new one."""
for client in await self._store.async_get_clients():
if client.name == name:
return client

return await self._store.async_create_client(
name, redirect_uris, no_secret)

async def async_get_client(self, client_id):
"""Get a client."""
return await self._store.async_get_client(client_id)
Expand Down Expand Up @@ -380,45 +389,52 @@ class AuthStore:
def __init__(self, hass):
"""Initialize the auth store."""
self.hass = hass
self.users = None
self.clients = None
self._users = None
self._clients = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id."""
if self.users is None:
if self._users is None:
await self.async_load()

return [
credentials
for user in self.users.values()
for user in self._users.values()
for credentials in user.credentials
if (credentials.auth_provider_type == provider_type and
credentials.auth_provider_id == provider_id)
]

async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
await self.async_load()

return list(self._users.values())

async def async_get_user(self, user_id):
"""Retrieve a user."""
if self.users is None:
if self._users is None:
await self.async_load()

return self.users.get(user_id)
return self._users.get(user_id)

async def async_get_or_create_user(self, credentials, auth_provider):
"""Get or create a new user for given credentials.

If link_user is passed in, the credentials will be linked to the passed
in user if the credentials are new.
"""
if self.users is None:
if self._users is None:
await self.async_load()

# New credentials, store in user
if credentials.is_new:
info = await auth_provider.async_user_meta_for_credentials(
credentials)
# Make owner and activate user if it's the first user.
if self.users:
if self._users:
is_owner = False
is_active = False
else:
Expand All @@ -430,11 +446,11 @@ async def async_get_or_create_user(self, credentials, auth_provider):
is_active=is_active,
name=info.get('name'),
)
self.users[new_user.id] = new_user
self._users[new_user.id] = new_user
await self.async_link_user(new_user, credentials)
return new_user

for user in self.users.values():
for user in self._users.values():
for creds in user.credentials:
if (creds.auth_provider_type == credentials.auth_provider_type
and creds.auth_provider_id ==
Expand All @@ -451,22 +467,30 @@ async def async_link_user(self, user, credentials):

async def async_remove_user(self, user):
"""Remove a user."""
self.users.pop(user.id)
self._users.pop(user.id)
await self.async_save()

async def async_create_refresh_token(self, user, client_id):
"""Create a new token for a user."""
local_user = await self.async_get_user(user.id)
if local_user is None:
raise ValueError('Invalid user')

local_client = await self.async_get_client(client_id)
if local_client is None:
raise ValueError('Invalid client_id')

refresh_token = RefreshToken(user, client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
return refresh_token

async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
if self.users is None:
if self._users is None:
await self.async_load()

for user in self.users.values():
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None:
return refresh_token
Expand All @@ -475,7 +499,7 @@ async def async_get_refresh_token(self, token):

async def async_create_client(self, name, redirect_uris, no_secret):
"""Create a new client."""
if self.clients is None:
if self._clients is None:
await self.async_load()

kwargs = {
Expand All @@ -487,29 +511,36 @@ async def async_create_client(self, name, redirect_uris, no_secret):
kwargs['secret'] = None

client = Client(**kwargs)
self.clients[client.id] = client
self._clients[client.id] = client
await self.async_save()
return client

async def async_get_clients(self):
"""Return all clients."""
if self._clients is None:
await self.async_load()

return list(self._clients.values())

async def async_get_client(self, client_id):
"""Get a client."""
if self.clients is None:
if self._clients is None:
await self.async_load()

return self.clients.get(client_id)
return self._clients.get(client_id)

async def async_load(self):
"""Load the users."""
data = await self._store.async_load()

# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self.users is not None:
if self._users is not None:
return

if data is None:
self.users = {}
self.clients = {}
self._users = {}
self._clients = {}
return

users = {
Expand Down Expand Up @@ -553,8 +584,8 @@ async def async_load(self):
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
}

self.users = users
self.clients = clients
self._users = users
self._clients = clients

async def async_save(self):
"""Save users."""
Expand All @@ -565,7 +596,7 @@ async def async_save(self):
'is_active': user.is_active,
'name': user.name,
}
for user in self.users.values()
for user in self._users.values()
]

credentials = [
Expand All @@ -576,7 +607,7 @@ async def async_save(self):
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self.users.values()
for user in self._users.values()
for credential in user.credentials
]

Expand All @@ -590,7 +621,7 @@ async def async_save(self):
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self.users.values()
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
]

Expand All @@ -601,7 +632,7 @@ async def async_save(self):
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self.users.values()
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
Expand All @@ -613,7 +644,7 @@ async def async_save(self):
'secret': client.secret,
'redirect_uris': client.redirect_uris,
}
for client in self.clients.values()
for client in self._clients.values()
]

data = {
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def add_manifest_json_key(key, val):
async def async_setup(hass, config):
"""Set up the serving of the frontend."""
if hass.auth.active:
client = await hass.auth.async_create_client(
client = await hass.auth.async_get_or_create_client(
'Home Assistant Frontend',
redirect_uris=['/'],
no_secret=True,
Expand Down
10 changes: 5 additions & 5 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,18 @@ def add_to_hass(self, hass):
def add_to_auth_manager(self, auth_mgr):
"""Test helper to add entry to hass."""
ensure_auth_manager_loaded(auth_mgr)
auth_mgr._store.users[self.id] = self
auth_mgr._store._users[self.id] = self
return self


@ha.callback
def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded."""
store = auth_mgr._store
if store.clients is None:
store.clients = {}
if store.users is None:
store.users = {}
if store._clients is None:
store._clients = {}
if store._users is None:
store._users = {}


class MockModule(object):
Expand Down
2 changes: 1 addition & 1 deletion tests/components/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
})
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
redirect_uris=[CLIENT_REDIRECT_URI])
hass.auth._store.clients[client.id] = client
hass.auth._store._clients[client.id] = client
if setup_api:
await async_setup_component(hass, 'api', {})
return await aiohttp_client(hass.http.app)
53 changes: 46 additions & 7 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage):
await flush_store(manager._store._store)

store2 = auth.AuthStore(hass)
await store2.async_load()
assert len(store2.users) == 1
assert store2.users[user.id] == user
users = await store2.async_get_users()
assert len(users) == 1
assert users[0] == user

assert len(store2.clients) == 1
assert store2.clients[client.id] == client
clients = await store2.async_get_clients()
assert len(clients) == 1
assert clients[0] == client


def test_access_token_expired():
Expand Down Expand Up @@ -224,15 +225,18 @@ def test_access_token_expired():
async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, 'bla')
access_token = manager.async_create_access_token(refresh_token)
refresh_token = await manager.async_create_refresh_token(user, client.id)
assert refresh_token.user.id is user.id
assert refresh_token.client_id is client.id

access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token

with patch('homeassistant.auth.dt_util.utcnow',
Expand All @@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass):

# Even with unpatched time, it should have been removed from manager
assert manager.async_get_access_token(access_token.token) is None


async def test_get_or_create_client(hass):
"""Test that get_or_create_client works."""
manager = await auth.auth_manager_from_config(hass, [])

client1 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client1.name is 'Test Client'

client2 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client2.id is client1.id


async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, 'bla')


async def test_cannot_create_refresh_token_with_invalide_user(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser(id='invalid-user')
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, client.id)