From 76af7b4178b34577d98236415ef1e16a99981c21 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Fri, 29 Jun 2018 09:55:03 -0700 Subject: [PATCH 1/6] Only create frontend client_id once --- homeassistant/auth.py | 23 ++++++++++++++++- homeassistant/components/frontend/__init__.py | 2 +- tests/test_auth.py | 25 +++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 22abcdf213ccc2..822c81fec600e6 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -1,7 +1,7 @@ """Provide an authentication layer for Home Assistant.""" import asyncio import binascii -from collections import OrderedDict +from collections import OrderedDict, Counter from datetime import datetime, timedelta import os import importlib @@ -337,6 +337,12 @@ async def async_create_client(self, name, *, redirect_uris=None, return await self._store.async_create_client( name, redirect_uris, no_secret) + async def async_get_or_create_client(self, name, *, redirect_uris=None, + no_secret=False): + """Find a client, if not exists, create a new one.""" + return await self._store.async_get_or_create_client( + name, redirect_uris, no_secret) + async def async_get_client(self, client_id): """Get a client.""" return await self._store.async_get_client(client_id) @@ -491,6 +497,21 @@ async def async_create_client(self, name, redirect_uris, no_secret): await self.async_save() return client + async def async_get_or_create_client(self, name, redirect_uris, no_secret): + """Find a client, if not exists, create a new one.""" + if self.clients is None: + await self.async_load() + + redirect_uris_counter = Counter(redirect_uris) + + for client_id, client in self.clients.items(): + if (client.name == name + and Counter(client.redirect_uris) == redirect_uris_counter + and no_secret == (client.secret is None)): + return client + + return await self.async_create_client(name, redirect_uris, no_secret) + async def async_get_client(self, client_id): """Get a client.""" if self.clients is None: diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 0e9d7612669152..24c639021801ce 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -201,7 +201,7 @@ def add_manifest_json_key(key, val): async def async_setup(hass, config): """Set up the serving of the frontend.""" if hass.auth.active: - client = await hass.auth.async_create_client( + client = await hass.auth.async_get_or_create_client( 'Home Assistant Frontend', redirect_uris=['/'], no_secret=True, diff --git a/tests/test_auth.py b/tests/test_auth.py index 4c0db71466e97d..e2828878226734 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -241,3 +241,28 @@ async def test_cannot_retrieve_expired_access_token(hass): # Even with unpatched time, it should have been removed from manager assert manager.async_get_access_token(access_token.token) is None + + +async def test_get_or_create_client(hass): + """Test that get_or_create_client works.""" + manager = await auth.auth_manager_from_config(hass, []) + + client1 = await manager.async_get_or_create_client( + 'Test Client', redirect_uris=['https://test.com/1']) + assert client1.name is 'Test Client' + + client2 = await manager.async_get_or_create_client( + 'Test Client', redirect_uris=['https://test.com/1']) + assert client2.id is client1.id + + client_no_secret = await manager.async_get_or_create_client( + 'Test Client', + redirect_uris=['https://test.com/1', '/'], + no_secret=True) + assert client_no_secret.id is not client1.id + + client_no_secret2 = await manager.async_get_or_create_client( + 'Test Client', + redirect_uris=['/', 'https://test.com/1'], # different order + no_secret=True) + assert client_no_secret2.id is client_no_secret.id From e2cdaf965e4bc8ff7e8e08e33ba2c5eb76af473e Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Fri, 29 Jun 2018 10:28:13 -0700 Subject: [PATCH 2/6] Check user and client_id before create refresh token --- homeassistant/auth.py | 8 ++++++++ tests/test_auth.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 822c81fec600e6..9a0452030e38cf 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -462,6 +462,14 @@ async def async_remove_user(self, user): async def async_create_refresh_token(self, user, client_id): """Create a new token for a user.""" + local_user = await self.async_get_user(user.id) + if local_user is None: + raise ValueError('Invalid user') + + local_client = await self.async_get_client(client_id) + if local_client is None: + raise ValueError('Invalid client_id') + refresh_token = RefreshToken(user, client_id) user.refresh_tokens[refresh_token.token] = refresh_token await self.async_save() diff --git a/tests/test_auth.py b/tests/test_auth.py index e2828878226734..c655e2e3c93f18 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -224,15 +224,18 @@ def test_access_token_expired(): async def test_cannot_retrieve_expired_access_token(hass): """Test that we cannot retrieve expired access tokens.""" manager = await auth.auth_manager_from_config(hass, []) + client = await manager.async_create_client('test') user = MockUser( id='mock-user', is_owner=False, is_active=False, name='Paulus', ).add_to_auth_manager(manager) - refresh_token = await manager.async_create_refresh_token(user, 'bla') - access_token = manager.async_create_access_token(refresh_token) + refresh_token = await manager.async_create_refresh_token(user, client.id) + assert refresh_token.user.id is user.id + assert refresh_token.client_id is client.id + access_token = manager.async_create_access_token(refresh_token) assert manager.async_get_access_token(access_token.token) is access_token with patch('homeassistant.auth.dt_util.utcnow', @@ -266,3 +269,25 @@ async def test_get_or_create_client(hass): redirect_uris=['/', 'https://test.com/1'], # different order no_secret=True) assert client_no_secret2.id is client_no_secret.id + + +async def test_cannot_create_refresh_token_with_invalide_client_id(hass): + """Test that we cannot create refresh token with invalid client id.""" + manager = await auth.auth_manager_from_config(hass, []) + user = MockUser( + id='mock-user', + is_owner=False, + is_active=False, + name='Paulus', + ).add_to_auth_manager(manager) + with pytest.raises(ValueError): + await manager.async_create_refresh_token(user, 'bla') + + +async def test_cannot_create_refresh_token_with_invalide_user(hass): + """Test that we cannot create refresh token with invalid client id.""" + manager = await auth.auth_manager_from_config(hass, []) + client = await manager.async_create_client('test') + user = MockUser(id='invalid-user') + with pytest.raises(ValueError): + await manager.async_create_refresh_token(user, client.id) From 804002e5fda90ae1a0cbcfa0a17d2f7158fd862e Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Fri, 29 Jun 2018 10:37:07 -0700 Subject: [PATCH 3/6] Lint --- homeassistant/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 9a0452030e38cf..9939a20a97bee0 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -512,7 +512,7 @@ async def async_get_or_create_client(self, name, redirect_uris, no_secret): redirect_uris_counter = Counter(redirect_uris) - for client_id, client in self.clients.items(): + for _, client in self.clients.items(): if (client.name == name and Counter(client.redirect_uris) == redirect_uris_counter and no_secret == (client.secret is None)): From 8de8a58d5b7ed5f9bbc9dae210ac6c77bf32fba3 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Fri, 29 Jun 2018 15:08:49 -0700 Subject: [PATCH 4/6] Follow code review comment --- homeassistant/auth.py | 19 +++++++------------ tests/test_auth.py | 12 ------------ 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 9939a20a97bee0..38f08412b82f5e 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -1,23 +1,22 @@ """Provide an authentication layer for Home Assistant.""" import asyncio import binascii -from collections import OrderedDict, Counter -from datetime import datetime, timedelta -import os import importlib import logging +import os import uuid +from collections import OrderedDict +from datetime import datetime, timedelta import attr import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant import data_entry_flow, requirements -from homeassistant.core import callback from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID -from homeassistant.util.decorator import Registry +from homeassistant.core import callback from homeassistant.util import dt as dt_util - +from homeassistant.util.decorator import Registry _LOGGER = logging.getLogger(__name__) @@ -510,12 +509,8 @@ async def async_get_or_create_client(self, name, redirect_uris, no_secret): if self.clients is None: await self.async_load() - redirect_uris_counter = Counter(redirect_uris) - - for _, client in self.clients.items(): - if (client.name == name - and Counter(client.redirect_uris) == redirect_uris_counter - and no_secret == (client.secret is None)): + for client in self.clients.values(): + if client.name == name: return client return await self.async_create_client(name, redirect_uris, no_secret) diff --git a/tests/test_auth.py b/tests/test_auth.py index c655e2e3c93f18..89222b455a79de 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -258,18 +258,6 @@ async def test_get_or_create_client(hass): 'Test Client', redirect_uris=['https://test.com/1']) assert client2.id is client1.id - client_no_secret = await manager.async_get_or_create_client( - 'Test Client', - redirect_uris=['https://test.com/1', '/'], - no_secret=True) - assert client_no_secret.id is not client1.id - - client_no_secret2 = await manager.async_get_or_create_client( - 'Test Client', - redirect_uris=['/', 'https://test.com/1'], # different order - no_secret=True) - assert client_no_secret2.id is client_no_secret.id - async def test_cannot_create_refresh_token_with_invalide_client_id(hass): """Test that we cannot create refresh token with invalid client id.""" From 0408b08dccb5f6592652e40ef8bdada339f91acc Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 1 Jul 2018 13:18:40 -0400 Subject: [PATCH 5/6] Minor clenaup --- homeassistant/auth.py | 77 +++++++++++++++++-------------- tests/common.py | 10 ++-- tests/components/auth/__init__.py | 2 +- tests/test_auth.py | 11 +++-- 4 files changed, 54 insertions(+), 46 deletions(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 38f08412b82f5e..8ef16aa4d29b46 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -339,7 +339,11 @@ async def async_create_client(self, name, *, redirect_uris=None, async def async_get_or_create_client(self, name, *, redirect_uris=None, no_secret=False): """Find a client, if not exists, create a new one.""" - return await self._store.async_get_or_create_client( + for client in await self._store.async_get_clients(): + if client.name == name: + return client + + return await self._store.async_create_client( name, redirect_uris, no_secret) async def async_get_client(self, client_id): @@ -385,29 +389,36 @@ class AuthStore: def __init__(self, hass): """Initialize the auth store.""" self.hass = hass - self.users = None - self.clients = None + self._users = None + self._clients = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) async def credentials_for_provider(self, provider_type, provider_id): """Return credentials for specific auth provider type and id.""" - if self.users is None: + if self._users is None: await self.async_load() return [ credentials - for user in self.users.values() + for user in self._users.values() for credentials in user.credentials if (credentials.auth_provider_type == provider_type and credentials.auth_provider_id == provider_id) ] + async def async_get_users(self): + """Retrieve all users.""" + if self._users is None: + await self.async_load() + + return list(self._users.values()) + async def async_get_user(self, user_id): """Retrieve a user.""" - if self.users is None: + if self._users is None: await self.async_load() - return self.users.get(user_id) + return self._users.get(user_id) async def async_get_or_create_user(self, credentials, auth_provider): """Get or create a new user for given credentials. @@ -415,7 +426,7 @@ async def async_get_or_create_user(self, credentials, auth_provider): If link_user is passed in, the credentials will be linked to the passed in user if the credentials are new. """ - if self.users is None: + if self._users is None: await self.async_load() # New credentials, store in user @@ -423,7 +434,7 @@ async def async_get_or_create_user(self, credentials, auth_provider): info = await auth_provider.async_user_meta_for_credentials( credentials) # Make owner and activate user if it's the first user. - if self.users: + if self._users: is_owner = False is_active = False else: @@ -435,11 +446,11 @@ async def async_get_or_create_user(self, credentials, auth_provider): is_active=is_active, name=info.get('name'), ) - self.users[new_user.id] = new_user + self._users[new_user.id] = new_user await self.async_link_user(new_user, credentials) return new_user - for user in self.users.values(): + for user in self._users.values(): for creds in user.credentials: if (creds.auth_provider_type == credentials.auth_provider_type and creds.auth_provider_id == @@ -456,7 +467,7 @@ async def async_link_user(self, user, credentials): async def async_remove_user(self, user): """Remove a user.""" - self.users.pop(user.id) + self._users.pop(user.id) await self.async_save() async def async_create_refresh_token(self, user, client_id): @@ -476,10 +487,10 @@ async def async_create_refresh_token(self, user, client_id): async def async_get_refresh_token(self, token): """Get refresh token by token.""" - if self.users is None: + if self._users is None: await self.async_load() - for user in self.users.values(): + for user in self._users.values(): refresh_token = user.refresh_tokens.get(token) if refresh_token is not None: return refresh_token @@ -488,7 +499,7 @@ async def async_get_refresh_token(self, token): async def async_create_client(self, name, redirect_uris, no_secret): """Create a new client.""" - if self.clients is None: + if self._clients is None: await self.async_load() kwargs = { @@ -500,27 +511,23 @@ async def async_create_client(self, name, redirect_uris, no_secret): kwargs['secret'] = None client = Client(**kwargs) - self.clients[client.id] = client + self._clients[client.id] = client await self.async_save() return client - async def async_get_or_create_client(self, name, redirect_uris, no_secret): + async def async_get_clients(self): """Find a client, if not exists, create a new one.""" - if self.clients is None: + if self._clients is None: await self.async_load() - for client in self.clients.values(): - if client.name == name: - return client - - return await self.async_create_client(name, redirect_uris, no_secret) + return list(self._clients.values()) async def async_get_client(self, client_id): """Get a client.""" - if self.clients is None: + if self._clients is None: await self.async_load() - return self.clients.get(client_id) + return self._clients.get(client_id) async def async_load(self): """Load the users.""" @@ -528,12 +535,12 @@ async def async_load(self): # Make sure that we're not overriding data if 2 loads happened at the # same time - if self.users is not None: + if self._users is not None: return if data is None: - self.users = {} - self.clients = {} + self._users = {} + self._clients = {} return users = { @@ -577,8 +584,8 @@ async def async_load(self): cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients'] } - self.users = users - self.clients = clients + self._users = users + self._clients = clients async def async_save(self): """Save users.""" @@ -589,7 +596,7 @@ async def async_save(self): 'is_active': user.is_active, 'name': user.name, } - for user in self.users.values() + for user in self._users.values() ] credentials = [ @@ -600,7 +607,7 @@ async def async_save(self): 'auth_provider_id': credential.auth_provider_id, 'data': credential.data, } - for user in self.users.values() + for user in self._users.values() for credential in user.credentials ] @@ -614,7 +621,7 @@ async def async_save(self): refresh_token.access_token_expiration.total_seconds(), 'token': refresh_token.token, } - for user in self.users.values() + for user in self._users.values() for refresh_token in user.refresh_tokens.values() ] @@ -625,7 +632,7 @@ async def async_save(self): 'created_at': access_token.created_at.isoformat(), 'token': access_token.token, } - for user in self.users.values() + for user in self._users.values() for refresh_token in user.refresh_tokens.values() for access_token in refresh_token.access_tokens ] @@ -637,7 +644,7 @@ async def async_save(self): 'secret': client.secret, 'redirect_uris': client.redirect_uris, } - for client in self.clients.values() + for client in self._clients.values() ] data = { diff --git a/tests/common.py b/tests/common.py index 1b8eabaa0db4bc..3a51cd3e059847 100644 --- a/tests/common.py +++ b/tests/common.py @@ -321,7 +321,7 @@ def add_to_hass(self, hass): def add_to_auth_manager(self, auth_mgr): """Test helper to add entry to hass.""" ensure_auth_manager_loaded(auth_mgr) - auth_mgr._store.users[self.id] = self + auth_mgr._store._users[self.id] = self return self @@ -329,10 +329,10 @@ def add_to_auth_manager(self, auth_mgr): def ensure_auth_manager_loaded(auth_mgr): """Ensure an auth manager is considered loaded.""" store = auth_mgr._store - if store.clients is None: - store.clients = {} - if store.users is None: - store.users = {} + if store._clients is None: + store._clients = {} + if store._users is None: + store._users = {} class MockModule(object): diff --git a/tests/components/auth/__init__.py b/tests/components/auth/__init__.py index f0b205ff5ce490..21719c12569b3b 100644 --- a/tests/components/auth/__init__.py +++ b/tests/components/auth/__init__.py @@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG, }) client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET, redirect_uris=[CLIENT_REDIRECT_URI]) - hass.auth._store.clients[client.id] = client + hass.auth._store._clients[client.id] = client if setup_api: await async_setup_component(hass, 'api', {}) return await aiohttp_client(hass.http.app) diff --git a/tests/test_auth.py b/tests/test_auth.py index 89222b455a79de..5b545223c15a9c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage): await flush_store(manager._store._store) store2 = auth.AuthStore(hass) - await store2.async_load() - assert len(store2.users) == 1 - assert store2.users[user.id] == user + users = await store2.async_get_users() + assert len(users) == 1 + assert users[0] == user - assert len(store2.clients) == 1 - assert store2.clients[client.id] == client + clients = await store2.async_get_clients() + assert len(clients) == 1 + assert clients[0] == client def test_access_token_expired(): From 20d5f52ff5f315945391787cb7f224ebab9e5d30 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 1 Jul 2018 13:21:16 -0400 Subject: [PATCH 6/6] Update doc string --- homeassistant/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/homeassistant/auth.py b/homeassistant/auth.py index 8ef16aa4d29b46..767776f7ad9310 100644 --- a/homeassistant/auth.py +++ b/homeassistant/auth.py @@ -516,7 +516,7 @@ async def async_create_client(self, name, redirect_uris, no_secret): return client async def async_get_clients(self): - """Find a client, if not exists, create a new one.""" + """Return all clients.""" if self._clients is None: await self.async_load()